
Python variables.create_global_step function code examples

This article collects typical usage examples of the Python function tensorflow.contrib.framework.python.ops.variables.create_global_step, taken from open-source projects. The function creates the global step variable for a graph (the default graph when no graph argument is passed) and raises a ValueError if that graph already has one, as Example 4 verifies.

17 examples of create_global_step are shown below, sorted by popularity.

Example 1: testLinearRegression

  def testLinearRegression(self):
    my_seed = 42
    config = run_config.RunConfig(tf_random_seed=my_seed)
    boston = base.load_boston()
    columns = [feature_column.real_valued_column('', dimension=13)]

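    # Train two regressors in separate graphs with the same seed; their
    # weights, biases, and predictions should match.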

    with ops.Graph().as_default() as g1:
      random.seed(my_seed)
      g1.seed = my_seed
      variables.create_global_step()
      regressor1 = linear.LinearRegressor(
          optimizer=_NULL_OPTIMIZER, feature_columns=columns, config=config)
      regressor1.fit(x=boston.data, y=boston.target, steps=1)

    with ops.Graph().as_default() as g2:
      random.seed(my_seed)
      g2.seed = my_seed
      variables.create_global_step()
      regressor2 = linear.LinearRegressor(
          optimizer=_NULL_OPTIMIZER, feature_columns=columns, config=config)
      regressor2.fit(x=boston.data, y=boston.target, steps=1)

    self.assertAllClose(regressor1.weights_, regressor2.weights_)
    self.assertAllClose(regressor1.bias_, regressor2.bias_)
    self.assertAllClose(
        list(regressor1.predict_scores(
            boston.data, as_iterable=True)),
        list(regressor2.predict_scores(
            boston.data, as_iterable=True)),
        atol=1e-05)
Developer ID: AliMiraftab, Project: tensorflow, Lines of code: 32, Source file: stability_test.py
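The stability test builds the same model twice, in two separate graphs with identical seeds, and expects matching weights, biases, and predictions. The idea can be shown without TensorFlow. In this minimal sketch, `init_weights` and `null_optimizer_step` are hypothetical helpers, and the zero-update step is an assumption about what `_NULL_OPTIMIZER` does, based only on its name.

```python
import random


def init_weights(seed, n):
    """Hypothetical stand-in for seeded variable initialization."""
    rng = random.Random(seed)  # private generator, analogous to a per-graph seed
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


def null_optimizer_step(weights):
    """Hypothetical 'null optimizer' step: a zero update leaves weights unchanged."""
    return [w - 0.0 for w in weights]


# Two independent "runs" built from the same seed, like g1 and g2 in the test.
run1 = null_optimizer_step(init_weights(42, 13))
run2 = null_optimizer_step(init_weights(42, 13))
assert run1 == run2
```

Exact equality works here because both runs perform identical computations; the real test compares predictions with `assertAllClose(..., atol=1e-05)` to tolerate floating-point noise.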


Example 2: export_estimator

def export_estimator(estimator, export_dir, input_fn=_default_input_fn,
                     signature_fn=_generic_signature_fn, default_batch_size=1,
                     exports_to_keep=None):
  """Exports inference graph into given dir.

  Args:
    estimator: Estimator to export
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    input_fn: Function that given `Tensor` of `Example` strings, parses it into
      features that are then passed to the model.
    signature_fn: Function that given `Tensor` of `Example` strings,
      `dict` of `Tensor`s for features and `dict` of `Tensor`s for predictions
      and returns default and named exporting signatures.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep.
  """
  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)
    examples = array_ops.placeholder(dtype=dtypes.string,
                                     shape=[default_batch_size],
                                     name='input_example_tensor')
    features = input_fn(estimator, examples)
    predictions = estimator._get_predict_ops(features)
    default_signature, named_graph_signatures = signature_fn(
        examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    _export_graph(g, _get_saver(), checkpoint_path, export_dir,
                  default_graph_signature=default_signature,
                  named_graph_signatures=named_graph_signatures,
                  exports_to_keep=exports_to_keep)
Developer ID: 363158858, Project: tensorflow, Lines of code: 33, Source file: export.py
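`export_estimator` turns the integer `exports_to_keep` into a filter via `gc.largest_export_versions` before handing it to `_export_graph`. Assuming that filter keeps the `n` highest export version numbers and everything else is garbage-collected (as its name suggests), the selection logic reduces to the sketch below. `keep_largest_versions` and `versions_to_delete` are hypothetical names, and plain integers stand in for export paths.

```python
def keep_largest_versions(n):
    """Return a filter that keeps the n largest export versions (n >= 1)."""
    def keep(versions):
        return sorted(versions)[-n:]
    return keep


def versions_to_delete(versions, keep_filter):
    """Everything the filter does not keep is a candidate for deletion."""
    kept = set(keep_filter(versions))
    return sorted(v for v in versions if v not in kept)


existing = [100, 250, 300, 420, 500]
keep_two = keep_largest_versions(2)
print(keep_two(existing))                      # [420, 500]
print(versions_to_delete(existing, keep_two))  # [100, 250, 300]
```

Returning a filter function rather than a list lets the caller decide later which directories exist and apply the same policy to them.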


Example 3: _export_estimator

def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep):
  input_fn = input_fn or _default_input_fn
  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)
    examples = array_ops.placeholder(dtype=dtypes.string,
                                     shape=[default_batch_size],
                                     name='input_example_tensor')
    features = input_fn(estimator, examples)
    predictions = estimator._get_predict_ops(features)

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(examples,
                                                               features,
                                                               predictions)
    else:
      try:
        # Some estimators provide a target_column of known type
        target_column = estimator._get_target_column()
        problem_type = target_column.problem_type

        if problem_type == layers.ProblemType.CLASSIFICATION:
          signature_fn = classification_signature_fn
        elif problem_type == layers.ProblemType.LINEAR_REGRESSION:
          signature_fn = regression_signature_fn
        elif problem_type == layers.ProblemType.LOGISTIC_REGRESSION:
          signature_fn = logistic_regression_signature_fn
        else:
          raise ValueError(
              'signature_fn must be provided because the TargetColumn is a %s, '
              'which does not have a standard problem type and so cannot use a '
              'standard export signature.' % type(target_column).__name__)

        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
      except AttributeError:
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now.  To maintain this behavior, '
            'pass:\n'
            '  signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    _export_graph(g, _get_saver(), checkpoint_path, export_dir,
                  default_graph_signature=default_signature,
                  named_graph_signatures=named_graph_signatures,
                  exports_to_keep=exports_to_keep)
Developer ID: JamesFysh, Project: tensorflow, Lines of code: 58, Source file: export.py
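When no `signature_fn` is passed, Example 3 chooses one from the target column's problem type and raises a `ValueError` for non-standard types. A table-driven variant of that if/elif chain is sketched below; the `ProblemType` enum is a hypothetical mirror of `layers.ProblemType` (its member values are not meant to match the library's), and the table maps to function names rather than the real signature functions.

```python
import enum


class ProblemType(enum.Enum):
    """Hypothetical mirror of the problem types checked in _export_estimator."""
    UNSPECIFIED = enum.auto()
    CLASSIFICATION = enum.auto()
    LINEAR_REGRESSION = enum.auto()
    LOGISTIC_REGRESSION = enum.auto()


# Table-driven replacement for the if/elif chain; values are just names here.
SIGNATURE_FN_BY_PROBLEM = {
    ProblemType.CLASSIFICATION: "classification_signature_fn",
    ProblemType.LINEAR_REGRESSION: "regression_signature_fn",
    ProblemType.LOGISTIC_REGRESSION: "logistic_regression_signature_fn",
}


def pick_signature_fn(problem_type):
    """Return the signature function name, or raise for non-standard types."""
    if problem_type not in SIGNATURE_FN_BY_PROBLEM:
        raise ValueError(
            "signature_fn must be provided for problem type %s" % problem_type.name)
    return SIGNATURE_FN_BY_PROBLEM[problem_type]
```

A lookup table keeps the supported cases in one place, so adding a problem type means adding one entry instead of another branch.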


Example 4: test_create_global_step

  def test_create_global_step(self):
    self.assertIsNone(variables_lib2.get_global_step())
    with ops.Graph().as_default() as g:
      global_step = variables_lib2.create_global_step()
      self._assert_global_step(global_step)
      self.assertRaisesRegex(ValueError, 'already exists',
                             variables_lib2.create_global_step)
      self.assertRaisesRegex(ValueError, 'already exists',
                             variables_lib2.create_global_step, g)
      self._assert_global_step(variables_lib2.create_global_step(ops.Graph()))
Developer ID: AliMiraftab, Project: tensorflow, Lines of code: 10, Source file: variables_test.py
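Example 4 pins down the contract the other examples rely on: no global step exists before one is created, a second `create_global_step` on the same graph raises a `ValueError` whose message matches "already exists", and a fresh `ops.Graph()` can get its own. In TensorFlow the global step is a scalar integer variable registered in the `GraphKeys.GLOBAL_STEP` collection. The toy model below only mimics the one-per-graph rule with a dictionary; its `Graph`, `create_global_step`, and `get_global_step` are hypothetical pure-Python stand-ins, not the library's implementation.

```python
class Graph:
    """Hypothetical toy graph that only tracks named collections."""

    def __init__(self):
        self.collections = {}


_DEFAULT_GRAPH = Graph()


def create_global_step(graph=None):
    """Toy model: at most one global step per graph, as the test expects."""
    graph = _DEFAULT_GRAPH if graph is None else graph
    if graph.collections.get("global_step"):
        raise ValueError('"global_step" already exists.')
    step = {"name": "global_step", "value": 0}
    graph.collections["global_step"] = [step]
    return step


def get_global_step(graph=None):
    """Return the graph's global step record, or None if it has none."""
    graph = _DEFAULT_GRAPH if graph is None else graph
    steps = graph.collections.get("global_step", [])
    return steps[0] if steps else None
```

Keying the registry on the graph object is what allows `create_global_step(ops.Graph())` to succeed inside a block where the default graph already has a global step.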


Example 5: setUp

  def setUp(self):
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    var = variable_scope.get_variable('var', initializer=0.0, use_resource=True)
    tensor = state_ops.assign_add(var, 1.0)
    self.summary_op = summary_lib.scalar('my_summary', tensor)

    with variable_scope.variable_scope('foo', use_resource=True):
      variables.create_global_step()
    self.train_op = training_util._increment_global_step(1)
Developer ID: Mazecreator, Project: tensorflow, Lines of code: 13, Source file: basic_session_run_hooks_test.py


Example 6: export_estimator

def export_estimator(estimator,
                     export_dir,
                     signature_fn=None,
                     input_fn=_default_input_fn,
                     default_batch_size=1,
                     exports_to_keep=None):
  """Exports inference graph into given dir.

  Args:
    estimator: Estimator to export
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    signature_fn: Function that given `Tensor` of `Example` strings,
      `dict` of `Tensor`s for features and `dict` of `Tensor`s for predictions
      and returns default and named exporting signatures.
    input_fn: Function that given `Tensor` of `Example` strings, parses it into
      features that are then passed to the model.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep.
  """
  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)
    examples = array_ops.placeholder(dtype=dtypes.string,
                                     shape=[default_batch_size],
                                     name='input_example_tensor')
    features = input_fn(estimator, examples)
    predictions = estimator._get_predict_ops(features)
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(examples,
                                                               features,
                                                               predictions)
    else:
      logging.warn(
          'Change warning: `signature_fn` will be required after 2016-08-01.\n'
          'Using generic signatures for now.  To maintain this behavior, '
          'pass:\n'
          '  signature_fn=export.generic_signature_fn\n'
          'Also consider passing a regression or classification signature; see '
          'cl/126430915 for an example.')
      default_signature, named_graph_signatures = generic_signature_fn(
          examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    _export_graph(g, _get_saver(), checkpoint_path, export_dir,
                  default_graph_signature=default_signature,
                  named_graph_signatures=named_graph_signatures,
                  exports_to_keep=exports_to_keep)
Developer ID: 10imaging, Project: tensorflow, Lines of code: 48, Source file: export.py
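Example 6 lets an explicit `signature_fn` take priority and otherwise logs a change warning and falls back to `generic_signature_fn`. The same control flow is sketched here with Python's standard `warnings` module in place of TensorFlow's logging. `resolve_signatures` and this `generic_signature_fn` body are hypothetical and only return plain dicts.

```python
import warnings


def generic_signature_fn(examples, features, predictions):
    """Hypothetical generic signature: examples in, predictions out."""
    return {"inputs": examples}, {"outputs": predictions}


def resolve_signatures(examples, features, predictions, signature_fn=None):
    """An explicit signature_fn wins; otherwise warn and use the generic one."""
    if signature_fn is None:
        # FutureWarning fits a "will be required later" notice and is shown by default.
        warnings.warn("signature_fn not provided; using generic signatures.",
                      FutureWarning)
        signature_fn = generic_signature_fn
    return signature_fn(examples, features, predictions)
```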


Example 7: test_evaluate_ready_for_local_init

  def test_evaluate_ready_for_local_init(self):
    with ops.Graph().as_default() as g, self.test_session(g):
      variables_lib.create_global_step()
      v = variables.Variable(1.0)
      w = variables.Variable(
          v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False)
      ready_for_local_init_op = variables.report_uninitialized_variables(
          variables.global_variables())
      ops.add_to_collection(ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
                            ready_for_local_init_op)
      _ = learn.graph_actions.evaluate(
          g,
          output_dir=self._output_dir,
          checkpoint_path=None,
          eval_dict={'a': v},
          max_steps=1)
Developer ID: AliMiraftab, Project: tensorflow, Lines of code: 16, Source file: graph_actions_test.py
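In Example 7 the local variable `w` is initialized from `v + 1`, so its initializer can only run after the global variable `v` has a value. The op stored under `GraphKeys.READY_FOR_LOCAL_INIT_OP` reports which global variables are still uninitialized, and an empty report is what signals readiness. A pure-Python sketch of that readiness check, with hypothetical helper names and variables represented by their names:

```python
def report_uninitialized(global_vars, initialized):
    """Names of global variables that are not yet initialized."""
    return [name for name in global_vars if name not in initialized]


def ready_for_local_init(global_vars, initialized):
    """Local init may run only when no global variable is left uninitialized."""
    return not report_uninitialized(global_vars, initialized)


global_vars = ["global_step", "v"]
print(ready_for_local_init(global_vars, {"global_step"}))       # False: v is missing
print(ready_for_local_init(global_vars, {"global_step", "v"}))  # True
```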


Example 8: test_train_worker_monitor

  def test_train_worker_monitor(self):
    # We need to explicitly set device due to check on non-chief workers
    # requiring all variables to have a device assigned.
    with ops.Graph().as_default() as g, g.device('/cpu:0'):
      global_step = variables_lib.create_global_step(g)
      train_op = state_ops.assign_add(global_step, 1)
      loss_op = constant_op.constant(2.0)
      summary.scalar('loss', loss_op)
      # Add explicit "local" init op to initialize all variables
      # as there's no chief to init here.
      init_op = variables.global_variables_initializer()
      ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, init_op)
      # Create worker monitors where one should be active on the worker
      # and the other chief exclusive.
      chief_exclusive_monitor = _BaseMonitorWrapper(False)
      all_workers_monitor = _BaseMonitorWrapper(True)
      with self.test_session(g):
        loss = learn.graph_actions.train(
            g,
            output_dir=self._output_dir,
            global_step_tensor=global_step,
            train_op=train_op,
            loss_op=loss_op,
            supervisor_is_chief=False,
            steps=1,
            monitors=[chief_exclusive_monitor, all_workers_monitor])
      self.assertEqual(2.0, loss)
      self.assertTrue(not chief_exclusive_monitor.is_active and
                      all_workers_monitor.is_active,
                      'Only non-chief runnable monitor must have been active.')
      self.assertTrue(not chief_exclusive_monitor.has_step and
                      all_workers_monitor.has_step,
                      'Only non-chief runnable monitor must have a step.')
Developer ID: AliMiraftab, Project: tensorflow, Lines of code: 33, Source file: graph_actions_test.py
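Example 8 trains as a non-chief worker and expects only the monitor that may run on every worker to become active. Assuming the boolean passed to `_BaseMonitorWrapper` means "run on all workers" (as the variable names `chief_exclusive_monitor` and `all_workers_monitor` suggest), the activation rule looks like this hypothetical sketch:

```python
from dataclasses import dataclass


@dataclass
class Monitor:
    """Hypothetical monitor; the flag mirrors the bool given to _BaseMonitorWrapper."""
    name: str
    run_on_all_workers: bool
    is_active: bool = False


def activate_monitors(monitors, is_chief):
    """The chief runs every monitor; other workers run only all-worker ones."""
    for mon in monitors:
        mon.is_active = is_chief or mon.run_on_all_workers
    return [mon.name for mon in monitors if mon.is_active]


monitors = [Monitor("chief_exclusive", False), Monitor("all_workers", True)]
print(activate_monitors(monitors, is_chief=False))  # ['all_workers']
print(activate_monitors(monitors, is_chief=True))   # ['chief_exclusive', 'all_workers']
```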


Example 9: testDNNRegression

  def testDNNRegression(self):
    my_seed = 42
    config = run_config.RunConfig(tf_random_seed=my_seed)
    boston = base.load_boston()
    columns = [feature_column.real_valued_column('', dimension=13)]

    with ops.Graph().as_default() as g1:
      random.seed(my_seed)
      g1.seed = my_seed
      variables.create_global_step()
      regressor1 = dnn.DNNRegressor(
          hidden_units=[10],
          feature_columns=columns,
          optimizer=_NULL_OPTIMIZER,
          config=config)
      regressor1.fit(x=boston.data, y=boston.target, steps=1)

    with ops.Graph().as_default() as g2:
      random.seed(my_seed)
      g2.seed = my_seed
      variables.create_global_step()
      regressor2 = dnn.DNNRegressor(
          hidden_units=[10],
          feature_columns=columns,
          optimizer=_NULL_OPTIMIZER,
          config=config)
      regressor2.fit(x=boston.data, y=boston.target, steps=1)

    weights1 = ([regressor1.get_variable_value('dnn/hiddenlayer_0/weights')] +
                [regressor1.get_variable_value('dnn/logits/weights')])
    weights2 = ([regressor2.get_variable_value('dnn/hiddenlayer_0/weights')] +
                [regressor2.get_variable_value('dnn/logits/weights')])
    for w1, w2 in zip(weights1, weights2):
      self.assertAllClose(w1, w2)

    biases1 = ([regressor1.get_variable_value('dnn/hiddenlayer_0/biases')] +
               [regressor1.get_variable_value('dnn/logits/biases')])
    biases2 = ([regressor2.get_variable_value('dnn/hiddenlayer_0/biases')] +
               [regressor2.get_variable_value('dnn/logits/biases')])
    for b1, b2 in zip(biases1, biases2):
      self.assertAllClose(b1, b2)
    self.assertAllClose(
        list(regressor1.predict_scores(
            boston.data, as_iterable=True)),
        list(regressor2.predict_scores(
            boston.data, as_iterable=True)),
        atol=1e-05)
Developer ID: Ajaycs99, Project: tensorflow, Lines of code: 47, Source file: stability_test.py


Example 10: _build_inference_graph

  def _build_inference_graph(self):
    """Build simple inference graph.

    This includes a regular variable, local variable, and fake table.

    Returns:
      Tuple of 3 `Tensor` objects, 2 input and 1 output.
    """
    variables_lib.create_global_step()
    in0 = variables.Variable(1.0)
    in1 = variables_lib.local_variable(2.0)
    fake_table = variables.Variable(
        3.0,
        trainable=False,
        collections=['fake_tables'],
        name='fake_table_var')
    in0.graph.add_to_collections([ops.GraphKeys.TABLE_INITIALIZERS],
                                 fake_table.initializer)
    out = in0 + in1 + fake_table
    return in0, in1, out
Developer ID: AliMiraftab, Project: tensorflow, Lines of code: 20, Source file: graph_actions_test.py


Example 11: test_train_loss

  def test_train_loss(self):
    with ops.Graph().as_default() as g, self.test_session(g):
      variables_lib.create_global_step()
      loss_var = variables_lib.local_variable(10.0)
      train_op = control_flow_ops.group(
          state_ops.assign_add(variables_lib.get_global_step(), 1),
          state_ops.assign_add(loss_var, -1.0))
      self._assert_summaries(self._output_dir)
      self._assert_ckpt(self._output_dir, False)
      loss = learn.graph_actions.train(
          g,
          output_dir=self._output_dir,
          train_op=train_op,
          loss_op=loss_var.value(),
          steps=6)
      # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the
      # SaverDef, so we can't add it to the summary assertion test below.
      # meta_graph_def = meta_graph.create_meta_graph_def()
      self.assertEqual(4.0, loss)
      self._assert_summaries(self._output_dir, expected_graphs=[g])
      self._assert_ckpt(self._output_dir, True)
Developer ID: AliMiraftab, Project: tensorflow, Lines of code: 21, Source file: graph_actions_test.py
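The expected loss of 4.0 in Example 11 follows from the train op: `loss_var` starts at 10.0 and each of the 6 steps adds -1.0, giving 10.0 - 6 × 1.0 = 4.0, assuming the reported loss is the value after the last step's update. A plain-Python replay of that arithmetic (`simulate_train` is a hypothetical helper):

```python
def simulate_train(initial_loss=10.0, delta=-1.0, steps=6):
    """Mimic train_op: each step adds 1 to global_step and delta to the loss."""
    global_step, loss = 0, initial_loss
    for _ in range(steps):
        global_step += 1
        loss += delta
    return global_step, loss


print(simulate_train())  # (6, 4.0)
```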


Example 12: test_requests

  def test_requests(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      variables_lib.create_global_step()
      mock_mon = FakeMonitor()
      mock_mon2 = FakeMonitor()

      hook = learn.monitors.RunHookAdapterForMonitors([mock_mon, mock_mon2])
      hook.begin()

      mon_sess = monitored_session._HookedSession(sess=sess, hooks=[hook])

      a_tensor = constant_op.constant([0], name='a_tensor')
      constant_op.constant([5], name='another_tensor')
      constant_op.constant([10], name='third_tensor')
      mock_mon.requested_tensors = ['another_tensor']
      mock_mon2.requested_tensors = ['third_tensor']
      sess.run(variables.global_variables_initializer())

      output = mon_sess.run(a_tensor)
      self.assertEqual(output, [0])
      self.assertEqual(mock_mon.output['another_tensor'], [5])
      self.assertEqual(mock_mon2.output['third_tensor'], [10])
Developer ID: AliMiraftab, Project: tensorflow, Lines of code: 22, Source file: monitors_test.py
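Example 12 checks that `RunHookAdapterForMonitors` fetches the extra tensors each monitor requests alongside the caller's own fetch, returns only `a_tensor` to the caller, and routes `another_tensor` and `third_tensor` back to the monitor that asked for each. A dictionary-based sketch of that routing: `hooked_run` is hypothetical, a dict lookup stands in for `Session.run`, and in TensorFlow the extra fetches would travel through a hook's `before_run`/`after_run` pair.

```python
def hooked_run(fetch, monitors, tensors):
    """Fetch `fetch` plus every tensor the monitors request, in one lookup."""
    requested = [name for mon in monitors for name in mon["requested"]]
    results = {name: tensors[name] for name in [fetch] + requested}
    for mon in monitors:
        # Each monitor only receives the tensors it asked for.
        mon["output"] = {name: results[name] for name in mon["requested"]}
    return results[fetch]  # the caller sees only its own fetch


tensors = {"a_tensor": [0], "another_tensor": [5], "third_tensor": [10]}
monitors = [{"requested": ["another_tensor"]}, {"requested": ["third_tensor"]}]
print(hooked_run("a_tensor", monitors, tensors))  # [0]
print(monitors[0]["output"], monitors[1]["output"])
# {'another_tensor': [5]} {'third_tensor': [10]}
```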


Example 13: test_calls_and_steps

  def test_calls_and_steps(self):
    with ops.Graph().as_default(), session_lib.Session() as sess:
      global_step_tensor = variables_lib.create_global_step()
      inc_5 = state_ops.assign_add(global_step_tensor, 5)
      mock_mon = FakeMonitor()
      mock_mon2 = FakeMonitor()

      hook = learn.monitors.RunHookAdapterForMonitors([mock_mon, mock_mon2])
      hook.begin()
      for mon in [mock_mon, mock_mon2]:
        self.assertEqual(mon.call_counter['begin'], 1)

      sess.run(variables.global_variables_initializer())
      sess.run(global_step_tensor.assign(10))

      mon_sess = monitored_session._HookedSession(sess=sess, hooks=[hook])

      mon_sess.run(inc_5)
      for mon in [mock_mon, mock_mon2]:
        self.assertEqual(mon.output, {})
        self.assertEqual(mon.last_begin_step, 11)
        self.assertEqual(mon.last_end_step, 11)
        self.assertEqual(mon.last_post_step, 11)
        self.assertEqual(mon.call_counter['step_end'], 1)
        self.assertEqual(mon.call_counter['step_begin'], 1)
        self.assertEqual(mon.call_counter['post_step'], 1)

      mon_sess.run(inc_5)
      for mon in [mock_mon, mock_mon2]:
        self.assertEqual(mon.output, {})
        self.assertEqual(mon.last_begin_step, 16)
        self.assertEqual(mon.last_end_step, 16)
        self.assertEqual(mon.last_post_step, 16)
        self.assertEqual(mon.call_counter['step_end'], 2)
        self.assertEqual(mon.call_counter['step_begin'], 2)
        self.assertEqual(mon.call_counter['post_step'], 2)

      hook.end(sess)
      for mon in [mock_mon, mock_mon2]:
        self.assertEqual(mon.call_counter['end'], 1)
Developer ID: AliMiraftab, Project: tensorflow, Lines of code: 40, Source file: monitors_test.py
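The expected steps 11 and 16 in Example 13 follow a simple rule: the monitors are told about the step that is about to run, namely the current global step plus one, while each `inc_5` run then advances the global step by 5 (10 → 15 → 20). A pure-Python replay of that bookkeeping under exactly that assumption, with `run_step_with_monitors` as a hypothetical helper:

```python
def run_step_with_monitors(global_step, increment, seen_steps):
    """Monitors see the step about to run (current + 1); then the op runs."""
    seen_steps.append(global_step + 1)
    return global_step + increment


seen = []
global_step = 10  # like sess.run(global_step_tensor.assign(10))
global_step = run_step_with_monitors(global_step, 5, seen)  # 10 -> 15
global_step = run_step_with_monitors(global_step, 5, seen)  # 15 -> 20
print(seen, global_step)  # [11, 16] 20
```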


Example 14: _export_estimator

def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None,
                      checkpoint_path=None):
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  # If checkpoint_path is specified, use the specified checkpoint path.
  checkpoint_path = (checkpoint_path or
                     tf_saver.latest_checkpoint(estimator._model_dir))
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    predictions = estimator._get_predict_ops(features).predictions

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(examples,
                                                               features,
                                                               predictions)
    else:
      try:
        # Some estimators provide a signature function.
        # TODO(zakaria): check if the estimator has this function,
        #   raise helpful error if not
        signature_fn = estimator._create_signature_fn()

        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
      except AttributeError:
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now.  To maintain this behavior, '
            'pass:\n'
            '  signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    return _export_graph(
        g,
        _get_saver(),
        checkpoint_path,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=exports_to_keep)
Developer ID: AlbertXiebnu, Project: tensorflow, Lines of code: 75, Source file: export.py
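With `use_deprecated_input_fn=False`, Example 14 calls `input_fn()` directly, optionally pops the raw serialized examples out of the feature dict under `input_feature_key`, and refuses to export when neither features nor examples remain. The sketch below isolates that branch in plain Python. `split_examples` is a hypothetical helper, and unlike the original it copies the feature dict instead of mutating it in place.

```python
def split_examples(input_fn, input_feature_key=None):
    """New-style branch: call input_fn, optionally pop the raw-example feature."""
    if input_fn is None:
        raise ValueError("input_fn must be defined.")
    features, _ = input_fn()  # input_fn returns (features, labels)
    features = dict(features)  # copy so the caller's dict is not mutated
    examples = None
    if input_feature_key is not None:
        examples = features.pop(input_feature_key)
    if not features and examples is None:
        raise ValueError("Either features or examples must be defined.")
    return features, examples
```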


Example 15: export_fn

  def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None):
    with ops.Graph().as_default() as g:
      contrib_variables.create_global_step(g)

      input_ops = feature_transforms.build_csv_serving_tensors_for_training_step(
          args.analysis, features, schema, stats, keep_target)
      model_fn_ops = estimator._call_model_fn(input_ops.features,
                                              None,
                                              model_fn_lib.ModeKeys.INFER)
      output_fetch_tensors = make_prediction_output_tensors(
          args=args,
          features=features,
          input_ops=input_ops,
          model_fn_ops=model_fn_ops,
          keep_target=keep_target)

      # Don't use signature_def_utils.predict_signature_def as that renames
      # tensor names if there is only 1 input/output tensor!
      signature_inputs = {key: tf.saved_model.utils.build_tensor_info(tensor)
                          for key, tensor in six.iteritems(input_ops.default_inputs)}
      signature_outputs = {key: tf.saved_model.utils.build_tensor_info(tensor)
                           for key, tensor in six.iteritems(output_fetch_tensors)}
      signature_def_map = {
          'serving_default':
              signature_def_utils.build_signature_def(
                  signature_inputs,
                  signature_outputs,
                  tf.saved_model.signature_constants.PREDICT_METHOD_NAME)}

      if not checkpoint_path:
        # Locate the latest checkpoint
        checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
      if not checkpoint_path:
        raise ValueError("Couldn't find trained model at %s."
                         % estimator._model_dir)

      export_dir = saved_model_export_utils.get_timestamped_export_dir(
          export_dir_base)

      if (model_fn_ops.scaffold is not None and
          model_fn_ops.scaffold.saver is not None):
        saver_for_restore = model_fn_ops.scaffold.saver
      else:
        saver_for_restore = saver.Saver(sharded=True)

      with tf_session.Session('') as session:
        saver_for_restore.restore(session, checkpoint_path)
        init_op = control_flow_ops.group(
            variables.local_variables_initializer(),
            resources.initialize_resources(resources.shared_resources()),
            tf.tables_initializer())

        # Perform the export
        builder = saved_model_builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            session, [tag_constants.SERVING],
            signature_def_map=signature_def_map,
            assets_collection=ops.get_collection(
                ops.GraphKeys.ASSET_FILEPATHS),
            legacy_init_op=init_op)
        builder.save(False)

      # Add the extra assets
      if assets_extra:
        assets_extra_path = os.path.join(compat.as_bytes(export_dir),
                                         compat.as_bytes('assets.extra'))
        for dest_relative, source in assets_extra.items():
          dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                       compat.as_bytes(dest_relative))
          dest_path = os.path.dirname(dest_absolute)
          file_io.recursive_create_dir(dest_path)
          file_io.copy(source, dest_absolute)

    # only keep the last 3 models
    saved_model_export_utils.garbage_collect_exports(
        export_dir_base,
        exports_to_keep=3)

    # save the last model to the model folder.
    # export_dir_base = A/B/intermediate_models/
    if keep_target:
      final_dir = os.path.join(args.job_dir, 'evaluation_model')
    else:
      final_dir = os.path.join(args.job_dir, 'model')
    if file_io.is_directory(final_dir):
      file_io.delete_recursively(final_dir)
    file_io.recursive_create_dir(final_dir)
    recursive_copy(export_dir, final_dir)

    return export_dir
Developer ID: javiervicho, Project: pydatalab, Lines of code: 90, Source file: task.py
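Both `export_fn` variants (Examples 15 and 16) finish with the same housekeeping: keep only the 3 newest exports under `export_dir_base`, then wipe and refill `model` (or `evaluation_model`) with a copy of the latest export. The sketch below reproduces those two steps with the standard library, assuming export directories are named with integer timestamps. `shutil` stands in for TensorFlow's `file_io` and `garbage_collect_exports`, and the helper names are hypothetical.

```python
import os
import shutil
import tempfile


def keep_newest_exports(export_dir_base, n=3):
    """Delete all but the n newest timestamp-named export dirs (n >= 1)."""
    dirs = sorted((d for d in os.listdir(export_dir_base) if d.isdigit()), key=int)
    for stale in dirs[:-n]:
        shutil.rmtree(os.path.join(export_dir_base, stale))
    return dirs[-n:]


def publish_export(export_dir, final_dir):
    """Replace final_dir with a fresh copy of export_dir."""
    if os.path.isdir(final_dir):
        shutil.rmtree(final_dir)
    shutil.copytree(export_dir, final_dir)


# Demo in a throwaway directory with four fake timestamped exports.
base = tempfile.mkdtemp()
for stamp in ["1500000001", "1500000002", "1500000003", "1500000004"]:
    os.makedirs(os.path.join(base, stamp))
with open(os.path.join(base, "1500000004", "saved_model.pb"), "w"):
    pass
kept = keep_newest_exports(base)
final_dir = os.path.join(tempfile.mkdtemp(), "model")
publish_export(os.path.join(base, kept[-1]), final_dir)
print(kept)  # ['1500000002', '1500000003', '1500000004']
```

Sorting with `key=int` keeps the order numeric even if timestamps ever differ in length.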


Example 16: export_fn

  def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None):
    with ops.Graph().as_default() as g:
      contrib_variables.create_global_step(g)

      input_ops = serving_from_csv_input(train_config, args, keep_target)
      model_fn_ops = estimator._call_model_fn(input_ops.features,
                                              None,
                                              model_fn_lib.ModeKeys.INFER)
      output_fetch_tensors = make_output_tensors(
          train_config=train_config,
          args=args,
          input_ops=input_ops,
          model_fn_ops=model_fn_ops,
          keep_target=keep_target)

      signature_def_map = {
        'serving_default': signature_def_utils.predict_signature_def(input_ops.default_inputs,
                                                                     output_fetch_tensors)
      }

      if not checkpoint_path:
        # Locate the latest checkpoint
        checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
      if not checkpoint_path:
        raise NotFittedError("Couldn't find trained model at %s."
                             % estimator._model_dir)

      export_dir = saved_model_export_utils.get_timestamped_export_dir(
          export_dir_base)

      with tf_session.Session('') as session:
        # variables.initialize_local_variables()
        variables.local_variables_initializer()
        data_flow_ops.tables_initializer()
        saver_for_restore = saver.Saver(
            variables.global_variables(),
            sharded=True)
        saver_for_restore.restore(session, checkpoint_path)

        init_op = control_flow_ops.group(
            variables.local_variables_initializer(),
            data_flow_ops.tables_initializer())

        # Perform the export
        builder = saved_model_builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(
            session, [tag_constants.SERVING],
            signature_def_map=signature_def_map,
            assets_collection=ops.get_collection(
                ops.GraphKeys.ASSET_FILEPATHS),
            legacy_init_op=init_op)
        builder.save(False)

      # Add the extra assets
      if assets_extra:
        assets_extra_path = os.path.join(compat.as_bytes(export_dir),
                                         compat.as_bytes('assets.extra'))
        for dest_relative, source in assets_extra.items():
          dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                       compat.as_bytes(dest_relative))
          dest_path = os.path.dirname(dest_absolute)
          gfile.MakeDirs(dest_path)
          gfile.Copy(source, dest_absolute)

    # only keep the last 3 models
    saved_model_export_utils.garbage_collect_exports(
        python_portable_string(export_dir_base),
        exports_to_keep=3)

    # save the last model to the model folder.
    # export_dir_base = A/B/intermediate_models/
    if keep_target:
      final_dir = os.path.join(args.job_dir, 'evaluation_model')
    else:
      final_dir = os.path.join(args.job_dir, 'model')
    if file_io.is_directory(final_dir):
      file_io.delete_recursively(final_dir)
    file_io.recursive_create_dir(final_dir)
    _recursive_copy(export_dir, final_dir)

    return export_dir
Developer ID: parthea, Project: pydatalab, Lines of code: 81, Source file: util.py
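Examples 15 and 16 only look up a checkpoint when `checkpoint_path` is not given, and fail fast when none exists. The sketch below shows that fallback order. `latest_by_step` is a hypothetical simplification that picks the name with the largest `-<step>` suffix (compared numerically, so `-2000` beats `-300`), whereas the real `latest_checkpoint` reads the `checkpoint` state file in the model directory. The sketch raises `ValueError` where Example 16 raises `NotFittedError`.

```python
import re


def latest_by_step(checkpoint_names):
    """Pick the checkpoint prefix with the highest trailing '-<step>' number."""
    steps = {}
    for name in checkpoint_names:
        match = re.search(r"-(\d+)$", name)
        if match:
            steps[name] = int(match.group(1))
    return max(steps, key=steps.get) if steps else None


def resolve_checkpoint(checkpoint_path, checkpoint_names, model_dir):
    """Use an explicit path if given, else the latest one, else fail."""
    checkpoint_path = checkpoint_path or latest_by_step(checkpoint_names)
    if not checkpoint_path:
        raise ValueError("Couldn't find trained model at %s." % model_dir)
    return checkpoint_path
```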


Example 17: _export_estimator

def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None):
  if use_deprecated_input_fn:
    input_fn = input_fn or _default_input_fn
  elif input_fn is None:
    raise ValueError('input_fn must be defined.')

  checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as g:
    contrib_variables.create_global_step(g)

    if use_deprecated_input_fn:
      examples = array_ops.placeholder(dtype=dtypes.string,
                                       shape=[default_batch_size],
                                       name='input_example_tensor')
      features = input_fn(estimator, examples)
    else:
      features, _ = input_fn()
      examples = None
      if input_feature_key is not None:
        examples = features.pop(input_feature_key)

    if (not features) and (examples is None):
      raise ValueError('Either features or examples must be defined.')

    # The default return type of _get_predict_ops is ModelFnOps. But there are
    # some subclasses of tf.contrib.learn.Estimator which override this
    # method and use the legacy signature, namely _get_predict_ops returns a
    # `predictions` Tensor or dict of Tensors. The following else-statement
    # code covers these cases, but will soon be deleted after the subclasses
    # are updated.
    # TODO(b/32664904): Update subclasses and delete the else-statement.
    infer_ops = estimator._get_predict_ops(features)
    if isinstance(infer_ops, model_fn.ModelFnOps):  # Default signature
      predictions = infer_ops.predictions
    else:  # Legacy signature
      predictions = infer_ops

    if prediction_key is not None:
      predictions = predictions[prediction_key]

    # Explicit signature_fn takes priority
    if signature_fn:
      default_signature, named_graph_signatures = signature_fn(examples,
                                                               features,
                                                               predictions)
    else:
      try:
        # Some estimators provide a signature function.
        # TODO(zakaria): check if the estimator has this function,
        #   raise helpful error if not
        signature_fn = estimator._create_signature_fn()

        default_signature, named_graph_signatures = (
            signature_fn(examples, features, predictions))
      except AttributeError:
        logging.warn(
            'Change warning: `signature_fn` will be required after '
            '2016-08-01.\n'
            'Using generic signatures for now.  To maintain this behavior, '
            'pass:\n'
            '  signature_fn=export.generic_signature_fn\n'
            'Also consider passing a regression or classification signature; '
            'see cl/126430915 for an example.')
        default_signature, named_graph_signatures = generic_signature_fn(
            examples, features, predictions)
    if exports_to_keep is not None:
      exports_to_keep = gc.largest_export_versions(exports_to_keep)
    return _export_graph(
        g,
        _get_saver(),
        checkpoint_path,
        export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_graph_signatures,
        exports_to_keep=exports_to_keep)
Developer ID: Y-owen, Project: tensorflow, Lines of code: 83, Source file: export.py
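Example 17 accepts both return types of `_get_predict_ops`: a `ModelFnOps`, whose `predictions` field is used, or a legacy bare predictions Tensor or dict. The normalization step in plain Python, with a hypothetical one-field `namedtuple` standing in for `model_fn.ModelFnOps` and lists standing in for Tensors:

```python
from collections import namedtuple

# Hypothetical stand-in for model_fn.ModelFnOps; only `predictions` is modeled.
ModelFnOps = namedtuple("ModelFnOps", ["predictions"])


def extract_predictions(infer_ops, prediction_key=None):
    """Accept either a ModelFnOps or legacy predictions, then select a key."""
    if isinstance(infer_ops, ModelFnOps):  # default signature
        predictions = infer_ops.predictions
    else:  # legacy signature: infer_ops already is the predictions
        predictions = infer_ops
    if prediction_key is not None:
        predictions = predictions[prediction_key]
    return predictions
```

Checking the wrapper type first means legacy estimators keep working until they are migrated, which is the transition the TODO in the original describes.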



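Note: The tensorflow.contrib.framework.python.ops.variables.create_global_step examples in this article were compiled by 纯净天空 from source code and documentation platforms such as GitHub and MSDocs. The code snippets are taken from open-source projects contributed by their respective developers, and copyright of the source code belongs to the original authors. Please refer to each project's license before distributing or using the code, and do not reproduce this article without permission.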

