• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python saver.latest_checkpoint函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了 Python 中 tensorflow.python.training.saver.latest_checkpoint 函数的典型用法代码示例。如果您正苦于以下问题:Python latest_checkpoint 函数的具体用法?Python latest_checkpoint 怎么用?Python latest_checkpoint 使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了 latest_checkpoint 函数的 20 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。

示例1: testRecoverSession

  def testRecoverSession(self):
    """Checks that `recover_session` restores a variable from a checkpoint.

    A checkpoint is written first. Recovery is then exercised through
    `checkpoint_dir`, through `checkpoint_filename_with_path`, and finally
    with both arguments at once, which is expected to raise `ValueError`.
    """
    # Start from an empty checkpoint directory.
    checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
    try:
      gfile.DeleteRecursively(checkpoint_dir)
    except errors.OpError:
      pass  # Best effort: a failed delete is deliberately ignored.
    gfile.MakeDirs(checkpoint_dir)

    with ops.Graph().as_default():
      v = variables.Variable(1, name="v")
      sm = session_manager.SessionManager(
          ready_op=variables.report_uninitialized_variables())
      saver = saver_lib.Saver({"v": v})
      # The directory is empty at this point, so nothing can be recovered.
      sess, initialized = sm.recover_session(
          "", saver=saver, checkpoint_dir=checkpoint_dir)
      self.assertFalse(initialized)
      sess.run(v.initializer)
      # `assertEqual` replaces the deprecated `assertEquals` alias, which no
      # longer exists on `unittest.TestCase` as of Python 3.12.
      self.assertEqual(1, sess.run(v))
      saver.save(sess,
                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    # Recover by directory, then by explicit checkpoint file.
    self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
    self._test_recovered_variable(
        checkpoint_filename_with_path=saver_lib.latest_checkpoint(
            checkpoint_dir))
    # Cannot set both checkpoint_dir and checkpoint_filename_with_path.
    with self.assertRaises(ValueError):
      self._test_recovered_variable(
          checkpoint_dir=checkpoint_dir,
          checkpoint_filename_with_path=saver_lib.latest_checkpoint(
              checkpoint_dir))
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:31,代码来源:session_manager_test.py


示例2: _read_config_files

  def _read_config_files(self, run_paths):
    """Loads the projector config of every run that provides one.

    Args:
      run_paths: Dict mapping a run name to that run's log directory.

    Returns:
      A `(configs, config_fpaths)` tuple of dicts keyed by run name: the parsed
      `ProjectorConfig` and the path of the config file it was read from.
      Runs without a config file, or whose checkpoint cannot be located, are
      left out of both dicts.
    """
    configs = {}
    config_fpaths = {}
    for run_name, logdir in run_paths.items():
      config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
      if not file_io.file_exists(config_fpath):
        # Skip runs that have no config file.
        continue
      # Read the config file.
      file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
      config = ProjectorConfig()
      text_format.Merge(file_content, config)

      if not config.model_checkpoint_path:
        # See if you can find a checkpoint file in the logdir.
        ckpt_path = latest_checkpoint(logdir)
        if not ckpt_path:
          # Or in the parent of logdir. The parent is `<logdir>/..`; joining
          # '../' in front of logdir would resolve relative to the working
          # directory (or be discarded entirely when logdir is absolute).
          ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
          if not ckpt_path:
            logging.warning('Cannot find model checkpoint in %s', logdir)
            continue
        config.model_checkpoint_path = ckpt_path

      # Sanity check for the checkpoint file.
      if not file_io.file_exists(config.model_checkpoint_path):
        logging.warning('Checkpoint file %s not found',
                        config.model_checkpoint_path)
        continue
      configs[run_name] = config
      config_fpaths[run_name] = config_fpath
    return configs, config_fpaths
开发者ID:KalraA,项目名称:tensorflow,代码行数:32,代码来源:plugin.py


示例3: _find_latest_checkpoint

def _find_latest_checkpoint(dir_path):
  """Looks up the latest checkpoint in `dir_path`, then in its parent.

  Args:
    dir_path: Directory to search first.

  Returns:
    The result of `latest_checkpoint` for `dir_path` when it is truthy,
    otherwise the result for the parent directory. `None` is returned when
    the lookup raises `errors.NotFoundError`.
  """
  try:
    # `or` only evaluates the parent lookup when the first one found nothing.
    return (latest_checkpoint(dir_path) or
            latest_checkpoint(os.path.join(dir_path, os.pardir)))
  except errors.NotFoundError:
    return None
开发者ID:chenjun0210,项目名称:tensorflow,代码行数:9,代码来源:projector_plugin.py


示例4: testMultiEvalStepIncrements

  def testMultiEvalStepIncrements(self):
    """Checks the final op value when eval ops also bump the eval step."""
    ckpt_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

    # One training step is enough to leave a checkpoint behind.
    self._train_model(ckpt_dir, num_steps=1)
    ckpt_path = saver.latest_checkpoint(ckpt_dir)

    # Build the model so the checkpoint has something to restore into.
    logistic_classifier(
        constant_op.constant(self._inputs, dtype=dtypes.float32))

    num_evals = 6

    counter = local_variable(0.0, name='MyVar')
    # Besides counting, the eval ops increase the eval step one more time.
    # NOTE(review): presumably this is why only half of `num_evals` updates
    # are expected below -- confirm against `_StopAfterNEvalsHook`.
    bump_counter = state_ops.assign_add(counter, 1.0)
    bump_eval_step = state_ops.assign_add(
        evaluation._get_or_create_eval_step(), 1, use_locking=True)
    expected_updates = num_evals // 2

    read_counter = array_ops.identity(counter)

    results = evaluation._evaluate_once(
        checkpoint_path=ckpt_path,
        eval_ops=[bump_counter, bump_eval_step],
        final_ops={'value': read_counter},
        hooks=[evaluation._StopAfterNEvalsHook(num_evals)])
    self.assertEqual(results['value'], expected_updates)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:28,代码来源:evaluation_test.py


示例5: testEvalOpAndFinalOp

  def testEvalOpAndFinalOp(self):
    """Checks the final op result and that the hooks list is not mutated."""
    ckpt_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

    # One training step is enough to leave a checkpoint behind.
    self._train_model(ckpt_dir, num_steps=1)
    ckpt_path = saver.latest_checkpoint(ckpt_dir)

    # Build the model so the checkpoint has something to restore into.
    logistic_classifier(
        constant_op.constant(self._inputs, dtype=dtypes.float32))

    num_evals = 5
    final_increment = 9.0

    counter = local_variable(0.0, name='MyVar')
    bump_counter = state_ops.assign_add(counter, 1.0)
    counter_plus_increment = array_ops.identity(counter) + final_increment

    hooks_passed = [evaluation._StopAfterNEvalsHook(num_evals)]
    # Shallow copy taken up front to detect mutation of the caller's list.
    hooks_before = list(hooks_passed)
    results = evaluation._evaluate_once(
        checkpoint_path=ckpt_path,
        eval_ops=bump_counter,
        final_ops={'value': counter_plus_increment},
        hooks=hooks_passed)
    self.assertEqual(results['value'], num_evals + final_increment)
    self.assertEqual(hooks_before, hooks_passed)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:27,代码来源:evaluation_test.py


示例6: testEvaluateWithFiniteInputs

  def testEvaluateWithFiniteInputs(self):
    """Evaluates a trained model over a single epoch of queued inputs.

    The inputs are limited to one epoch by the input producer, and the test
    asserts both the resulting accuracy and the number of eval steps taken.
    """
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_with_finite_inputs')

    # Train a Model to completion:
    self._train_model(checkpoint_dir, num_steps=300)

    # Run evaluation. Inputs are fed through input producer for one epoch.
    all_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    all_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    single_input, single_label = training.slice_input_producer(
        [all_inputs, all_labels], num_epochs=1)
    inputs, labels = training.batch([single_input, single_label], batch_size=6,
                                    allow_smaller_final_batch=True)

    logits = logistic_classifier(inputs)
    predictions = math_ops.round(logits)

    accuracy, update_op = metrics.accuracy(
        predictions=predictions, labels=labels)

    checkpoint_path = saver.latest_checkpoint(checkpoint_dir)

    # NOTE(review): `_StopAfterNEvalsHook(None)` presumably sets no eval limit,
    # leaving the end of input to stop evaluation -- confirm.
    final_ops_values = evaluation._evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=update_op,
        final_ops={'accuracy': accuracy,
                   'eval_steps': evaluation._get_or_create_eval_step()},
        hooks=[evaluation._StopAfterNEvalsHook(None),])
    # assertGreater reports the actual accuracy on failure, whereas
    # assertTrue(x > y) would only report "False is not true".
    self.assertGreater(final_ops_values['accuracy'], .99)
    # Runs evaluation for 4 iterations. First 2 evaluate full batch of 6 inputs
    # each; the 3rd iter evaluates the remaining 4 inputs, and the last one
    # triggers an error which stops evaluation.
    self.assertEqual(final_ops_values['eval_steps'], 4)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:35,代码来源:evaluation_test.py


示例7: test_recovery

 def test_recovery(self):
   """Checks that a restarted MonitoredSession resumes from the checkpoint.

   The global step is advanced to 2 in a first session that checkpoints via
   a `CheckpointSaverHook`. Two further sessions are then created, one given
   the checkpoint directory and one given the explicit checkpoint file, and
   each is expected to read the global step back as 2.
   """
   logdir = _test_dir(self.get_temp_dir(), 'test_recovery')
   with ops.Graph().as_default():
     gstep = variables_lib.get_or_create_global_step()
     do_step = state_ops.assign_add(gstep, 1)
     scaffold = monitored_session.Scaffold()
     # Use a hook to save the model every step (save_steps=1).  It also saves
     # it at the end.
     hooks = [
         basic_session_run_hooks.CheckpointSaverHook(
             logdir, save_steps=1, scaffold=scaffold)
     ]
     with monitored_session.MonitoredSession(
         session_creator=monitored_session.ChiefSessionCreator(
             scaffold, checkpoint_dir=logdir),
         hooks=hooks) as session:
       self.assertEqual(0, session.run(gstep))
       self.assertEqual(1, session.run(do_step))
       self.assertEqual(2, session.run(do_step))
     # A restart will find the checkpoint and recover automatically.
     with monitored_session.MonitoredSession(
         session_creator=monitored_session.ChiefSessionCreator(
             scaffold, checkpoint_dir=logdir)) as session:
       self.assertEqual(2, session.run(gstep))
     # A restart that is handed the checkpoint file directly (instead of the
     # directory) recovers the same state.
     with monitored_session.MonitoredSession(
         session_creator=monitored_session.ChiefSessionCreator(
             scaffold,
             checkpoint_filename_with_path=saver_lib.latest_checkpoint(
                 logdir))) as session:
       self.assertEqual(2, session.run(gstep))
开发者ID:kadeng,项目名称:tensorflow,代码行数:31,代码来源:monitored_session_test.py


示例8: export_estimator

def export_estimator(estimator, export_dir, input_fn=_default_input_fn,
                     signature_fn=_generic_signature_fn, default_batch_size=1,
                     exports_to_keep=None):
  """Exports the estimator's inference graph into `export_dir`.

  The checkpoint that is exported is the latest one found in
  `estimator._model_dir`.

  Args:
    estimator: Estimator to export.
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    input_fn: Function that, given the estimator and a `Tensor` of `Example`
      strings, parses it into features that are then passed to the model.
    signature_fn: Function that takes the `Tensor` of `Example` strings, the
      `dict` of `Tensor`s for features and the `dict` of `Tensor`s for
      predictions, and returns the default and named exporting signatures.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep, or `None` to pass no filter.
  """
  latest_ckpt = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as graph:
    contrib_variables.create_global_step(graph)
    example_strings = array_ops.placeholder(
        dtype=dtypes.string,
        shape=[default_batch_size],
        name='input_example_tensor')
    parsed_features = input_fn(estimator, example_strings)
    predict_ops = estimator._get_predict_ops(parsed_features)
    default_sig, named_sigs = signature_fn(
        example_strings, parsed_features, predict_ops)
    # Turn the plain count into the filter object the exporter receives.
    keep_filter = (None if exports_to_keep is None
                   else gc.largest_export_versions(exports_to_keep))
    _export_graph(graph, _get_saver(), latest_ckpt, export_dir,
                  default_graph_signature=default_sig,
                  named_graph_signatures=named_sigs,
                  exports_to_keep=keep_filter)
开发者ID:363158858,项目名称:tensorflow,代码行数:33,代码来源:export.py


示例9: _restore_or_save_initial_ckpt

  def _restore_or_save_initial_ckpt(self, session):
    """Restores from the latest checkpoint, or saves one if none exists.

    Args:
      session: Session used for the restore, or for reading the global step
        and saving.
    """
    # This would ideally live in after_create_session, but there is currently
    # no way to enforce the order in which `SessionRunHooks` run, so the
    # `_DatasetInitializerHook` might run *after* this hook. That would be a
    # problem in both cases:
    # 1. With an existing checkpoint, the initializer hook would override the
    #    state this hook just restored.
    # 2. Without a checkpoint, this hook would try to save an iterator that
    #    has not been initialized yet, which results in an exception.
    #
    # As a temporary fix there is an implicit contract between this hook and
    # the _DatasetInitializerHook:
    # 1. The _DatasetInitializerHook initializes the iterator in its
    #    after_create_session.
    # 2. This hook saves the iterator on the first call to `before_run()`,
    #    which is guaranteed to happen after `after_create_session()` of all
    #    hooks has run.

    # pylint: disable=protected-access
    saver_hook = self._checkpoint_saver_hook
    existing_ckpt = saver_lib.latest_checkpoint(
        saver_hook._checkpoint_dir,
        latest_filename=self._latest_filename)
    if existing_ckpt:
      # A checkpoint is already there: restore from it and stop.
      saver_hook._get_saver().restore(session, existing_ckpt)
      return

    # No checkpoint yet: save the state at step "global_step".
    # Note: We do not save the GraphDef or MetaGraphDef here.
    current_step = session.run(saver_hook._global_step_tensor)
    saver_hook._save(session, current_step)
    saver_hook._timer.update_last_triggered_step(current_step)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:33,代码来源:iterator_ops.py


示例10: _evaluate_model

  def _evaluate_model(self,
                      input_fn,
                      steps,
                      feed_fn=None,
                      metrics=None,
                      name=''):
    """Evaluates the latest checkpoint in the model directory.

    Args:
      input_fn: Callable returning a `(features, targets)` pair.
      steps: Passed to `evaluate` as `max_steps`.
      feed_fn: Optional feed function forwarded to `evaluate`.
      metrics: Metrics to compute; when `None`, the result of
        `_get_default_metric_functions()` is used.
      name: Optional suffix; results go to `eval_<name>` under the model
        directory, or to `eval` when empty.

    Returns:
      The evaluation results from `evaluate`, or `None` when the configured
      execution mode is not one of 'all', 'evaluate' or 'eval_evalset'.
    """
    if self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset'):
      return

    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    eval_subdir = 'eval_' + name if name else 'eval'
    eval_dir = os.path.join(self._model_dir, eval_subdir)
    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      step_tensor = contrib_framework.create_global_step(graph)
      features, targets = input_fn()
      self._check_inputs(features, targets)
      if metrics is not None:
        metric_fns = metrics
      else:
        metric_fns = self._get_default_metric_functions()
      eval_ops = self._get_eval_ops(features, targets, metric_fns)
      update_op, eval_ops = self._extract_metric_update_ops(eval_ops)
      results, _ = evaluate(
          graph=graph,
          output_dir=eval_dir,
          checkpoint_path=checkpoint_path,
          eval_dict=eval_ops,
          update_op=update_op,
          global_step_tensor=step_tensor,
          supervisor_master=self._config.master,
          feed_fn=feed_fn,
          max_steps=steps)
      return results
开发者ID:01-,项目名称:tensorflow,代码行数:31,代码来源:estimator.py


示例11: _infer_model

  def _infer_model(
      self, input_fn, feed_fn=None, outputs=None, as_iterable=False):
    """Runs prediction with the latest checkpoint in the model directory.

    Args:
      input_fn: Input function handed to `_get_features_from_input_fn`.
      feed_fn: Optional feed function forwarded to the inference helper.
      outputs: Optional collection of prediction keys to keep.
      as_iterable: Selects `_infer_model_as_iterable` over
        `_infer_model_single`.

    Returns:
      Whatever the selected inference helper returns.

    Raises:
      NotFittedError: If no checkpoint is found in the model directory.
      ValueError: If `outputs` filters out every prediction.
    """
    # A missing checkpoint means the model has not been trained.
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    if not checkpoint_path:
      raise NotFittedError("Couldn't find trained model at %s."
                           % self._model_dir)

    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(graph)
      features = self._get_features_from_input_fn(input_fn)
      predictions = self._get_predict_ops(features)
      # A single (non-dict) output is wrapped into a dict for uniform
      # handling; `return_dict` remembers to unwrap it again later.
      return_dict = isinstance(predictions, dict)
      if not return_dict:
        predictions = {'predictions': predictions}

      # Restrict the predictions to the requested outputs, if any.
      if outputs:
        all_predictions = predictions
        existing_keys = all_predictions.keys()
        predictions = {
            key: all_predictions[key]
            for key in all_predictions
            if key in outputs
        }
        if not predictions:
          raise ValueError('Expected to run at least one output from %s, '
                           'provided %s.' % (existing_keys, outputs))

      run_inference = (self._infer_model_as_iterable if as_iterable
                       else self._infer_model_single)
      return run_inference(checkpoint_path, predictions, feed_fn, return_dict)
开发者ID:Nishant23,项目名称:tensorflow,代码行数:35,代码来源:estimator.py


示例12: _infer_model

  def _infer_model(self,
                   x=None, input_fn=None, feed_fn=None,
                   batch_size=None, axis=None, proba=False):
    """Runs prediction with the latest checkpoint in the model directory.

    Args:
      x: Optional in-memory input; when given, `input_fn` and `feed_fn` are
        replaced by the pair from `_get_predict_input_fn`.
      input_fn: Callable returning a `(features, _)` pair.
      feed_fn: Optional callable producing feed dicts; prediction stops when
        it raises `StopIteration` or returns `None`.
      batch_size: Batch size for `x`; `None` is replaced by -1.
      axis: Not used by this method.
      proba: Not used by this method.

    Returns:
      Without `feed_fn`, the result of a single `infer` call. Otherwise a
      dict mapping each output key to the per-feed results concatenated
      along axis 0.
    """
    if batch_size is None:
      batch_size = -1
    if x is not None:
      input_fn, feed_fn = _get_predict_input_fn(x, batch_size)

    checkpoint_path = saver.latest_checkpoint(self._model_dir)
    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      contrib_framework.create_global_step(graph)
      features, _ = input_fn()
      predictions = self._get_predict_ops(features)
      if not isinstance(predictions, dict):
        predictions = {'predictions': predictions}
      # TODO(ipolosukhin): Support batching
      if feed_fn is None:
        return infer(checkpoint_path, predictions)

      # Collect the outputs of every feed, grouped by output key.
      collected = {}
      while True:
        try:
          feed_dict = feed_fn()
        except StopIteration:
          break
        if feed_dict is None:
          break
        batch_outputs = infer(
            checkpoint_path, predictions, feed_dict=feed_dict)
        for key in batch_outputs:
          collected.setdefault(key, []).append(batch_outputs[key])
      return {
          key: np.concatenate(chunks, axis=0)
          for key, chunks in collected.items()
      }
开发者ID:01-,项目名称:tensorflow,代码行数:35,代码来源:estimator.py


示例13: testUsageGraph

 def testUsageGraph(self):
   """Expected usage when graph building.

   Runs three training continuations, each in a fresh graph. Every
   continuation restores from the latest checkpoint in the shared directory
   (if any), trains for `num_training_steps` steps and saves again. The
   global step and save counter are asserted to accumulate across them.
   """
   with context.graph_mode():
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     for training_continuation in range(3):
       with ops.Graph().as_default():
         network = MyNetwork()
         optimizer = adam.AdamOptimizer(0.001)
         root = checkpointable_utils.Checkpoint(
             optimizer=optimizer, network=network,
             global_step=training_util.get_or_create_global_step())
         input_value = constant_op.constant([[3.]])
         train_op = optimizer.minimize(
             network(input_value),
             global_step=root.global_step)
         checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
         with self.test_session(graph=ops.get_default_graph()) as session:
           status = root.restore(save_path=checkpoint_path)
           status.initialize_or_restore(session=session)
           if checkpoint_path is None:
             # No checkpoint is only expected on the very first continuation,
             # where nothing was restored and assert_consumed() must fail.
             self.assertEqual(0, training_continuation)
             with self.assertRaises(AssertionError):
               status.assert_consumed()
           else:
             status.assert_consumed()
           for _ in range(num_training_steps):
             session.run(train_op)
           root.save(file_prefix=checkpoint_prefix, session=session)
           # Both counters carry over from the previous continuations.
           self.assertEqual((training_continuation + 1) * num_training_steps,
                            session.run(root.global_step))
           self.assertEqual(training_continuation + 1,
                            session.run(root.save_counter))
开发者ID:neuroradiology,项目名称:tensorflow,代码行数:34,代码来源:checkpointable_utils_test.py


示例14: create_session

 def create_session(self, checkpoint_dir):
     """Builds a MonitoredSession from the latest checkpoint in a directory.

     Args:
       checkpoint_dir: Directory that is searched for the latest checkpoint.

     Returns:
       A `training.MonitoredSession` whose session creator is given that
       checkpoint and the config from `self._session_config()`.
     """
     latest = saver.latest_checkpoint(checkpoint_dir)
     creator = training.ChiefSessionCreator(
         checkpoint_filename_with_path=latest,
         config=self._session_config())
     return training.MonitoredSession(session_creator=creator)
开发者ID:ucam-smt,项目名称:sgnmt,代码行数:7,代码来源:tf_nizza.py


示例15: testDeferredRestorationUsageEager

 def testDeferredRestorationUsageEager(self):
   """An idiomatic eager execution example.

   Runs three training continuations. Each one restores from the latest
   checkpoint in the shared directory, takes `num_training_steps` optimizer
   steps and saves again. The global step is asserted to accumulate across
   continuations.
   """
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   latest_object_graph = None  # Will be saved with the checkpoint eventually.
   for training_continuation in range(3):
     with ops.Graph().as_default():
       network = MyNetwork()
       optimizer = CheckpointableAdam(0.001)
       root = Root(optimizer=optimizer, network=network)
       # The object graph returned by the previous save (None on the first
       # pass) is handed back in alongside the checkpoint path.
       checkpointable.restore(
           save_path=core_saver.latest_checkpoint(checkpoint_directory),
           root_checkpointable=root,
           object_graph_proto=latest_object_graph)
       for _ in range(num_training_steps):
         # TODO(allenl): Use a Dataset and serialize/checkpoint it.
         input_value = constant_op.constant([[3.]])
         optimizer.minimize(
             lambda: network(input_value),  # pylint: disable=cell-var-from-loop
             global_step=root.global_step)
       # Keep the object graph for the next continuation's restore call.
       latest_object_graph, _ = checkpointable.save(
           file_prefix=checkpoint_prefix,
           root_checkpointable=root)
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        root.global_step.numpy())
开发者ID:japrogramer,项目名称:tensorflow,代码行数:26,代码来源:checkpointable_test.py


示例16: predict

  def predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None):
    """Yields predictions for the features produced by `input_fn`.

    This is a generator: nothing below the docstring runs (including the
    checks that raise) until the first item is requested.

    Args:
      input_fn: Input function returning features which is a dictionary of
        string feature name to `Tensor` or `SparseTensor`. If it returns a
        tuple, first item is extracted as features. Prediction continues until
        `input_fn` raises an end-of-input exception (`OutOfRangeError` or
        `StopIteration`).
      predict_keys: list of `str`, name of the keys to predict. It is used if
        the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used
        then rest of the predictions will be filtered from the dictionary. If
        `None`, returns all.
      hooks: List of `SessionRunHook` subclass instances. Used for callbacks
        inside the prediction call.
      checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
        latest checkpoint in `model_dir` is used.

    Yields:
      Evaluated values of `predictions` tensors, one example at a time.

    Raises:
      ValueError: Could not find a trained model in model_dir.
      ValueError: if batch length of predictions are not same.
      ValueError: If there is a conflict between `predict_keys` and
        `predictions`. For example if `predict_keys` is not `None` but
        `EstimatorSpec.predictions` is not a `dict`.
    """
    hooks = _check_hooks_type(hooks)
    # Fall back to the latest checkpoint; having none means "not trained".
    checkpoint_path = (checkpoint_path or
                       saver.latest_checkpoint(self._model_dir))
    if not checkpoint_path:
      raise ValueError('Could not find trained model in model_dir: {}.'.format(
          self._model_dir))

    with ops.Graph().as_default() as graph:
      random_seed.set_random_seed(self._config.tf_random_seed)
      training.create_global_step(graph)
      features = self._get_features_from_input_fn(input_fn)
      estimator_spec = self._call_model_fn(features, None,
                                           model_fn_lib.ModeKeys.PREDICT)
      predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
      session_creator = training.ChiefSessionCreator(
          checkpoint_filename_with_path=checkpoint_path,
          scaffold=estimator_spec.scaffold,
          config=config_pb2.ConfigProto(allow_soft_placement=True))
      with training.MonitoredSession(
          session_creator=session_creator, hooks=hooks) as mon_sess:
        yields_dicts = isinstance(predictions, dict)
        while not mon_sess.should_stop():
          batch = mon_sess.run(predictions)
          if yields_dicts:
            # Split the batched dict into one dict per example.
            for i in range(self._extract_batch_length(batch)):
              yield {
                  key: value[i]
                  for key, value in six.iteritems(batch)
              }
          else:
            for pred in batch:
              yield pred
开发者ID:LugarkPirog,项目名称:tensorflow,代码行数:60,代码来源:estimator.py


示例17: testAgnosticUsage

 def testAgnosticUsage(self):
   """Graph/eager agnostic usage.

   Runs three training continuations that restore from the latest checkpoint
   in the shared directory, train for `num_training_steps` steps and save
   again. The global step and save counter are asserted to accumulate.
   """
   # Does create garbage when executing eagerly due to ops.Graph() creation.
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   for training_continuation in range(3):
     with ops.Graph().as_default(), self.test_session(
         graph=ops.get_default_graph()):
       network = MyNetwork()
       optimizer = adam.AdamOptimizer(0.001)
       root = checkpointable_utils.Checkpoint(
           optimizer=optimizer, network=network,
           global_step=training_util.get_or_create_global_step())
       checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       input_value = constant_op.constant([[3.]])
       # Calling train_fn() invokes optimizer.minimize on the network output.
       train_fn = functools.partial(
           optimizer.minimize,
           functools.partial(network, input_value),
           global_step=root.global_step)
       if context.in_graph_mode():
         # In graph mode, call minimize once here and make train_fn evaluate
         # the value it returned on every step instead.
         train_fn = functools.partial(self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       root.save(file_prefix=checkpoint_prefix)
       # Both counters carry over from the previous continuations.
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
开发者ID:neuroradiology,项目名称:tensorflow,代码行数:31,代码来源:checkpointable_utils_test.py


示例18: create_session

def create_session(checkpoint_path, n_cpu_threads=-1):
    """Creates a MonitoredSession.

    Args:
      checkpoint_path (string): Path either to checkpoint directory or
                                directly to a checkpoint file.
      n_cpu_threads (int): Number of CPU threads. If negative, we
                           assume either GPU decoding or that all
                           CPU cores can be used.
    Returns:
      A TensorFlow MonitoredSession.

    Raises:
      AttributeError: If TensorFlow raises a ``NotFoundError`` while the
                      checkpoint is looked up or the session is created.
    """
    try:
        if os.path.isdir(checkpoint_path):
            # A directory was given: use the latest checkpoint inside it.
            checkpoint_path = saver.latest_checkpoint(checkpoint_path)
        else:
            # Pass the path as a logging argument so that formatting only
            # happens when the record is actually emitted.
            logging.info("%s is not a directory. Interpreting as direct "
                         "path to checkpoint...", checkpoint_path)
        return training.MonitoredSession(
            session_creator=training.ChiefSessionCreator(
                checkpoint_filename_with_path=checkpoint_path,
                config=session_config(n_cpu_threads)))
    except tf.errors.NotFoundError as e:
        # Log the underlying error as well; the AttributeError raised below
        # carries no detail about what was actually missing.
        logging.fatal("Could not find all variables of the computation "
            "graph in the T2T checkpoint file. This means that the "
            "checkpoint does not correspond to the model specified in "
            "SGNMT. Please double-check pred_src_vocab_size, "
            "pred_trg_vocab_size, and all the t2t_* parameters. "
            "Also make sure that the checkpoint exists and is readable. "
            "Original error: %s", e)
        raise AttributeError("Could not initialize TF session.")
开发者ID:ucam-smt,项目名称:sgnmt,代码行数:30,代码来源:tf_utils.py


示例19: every_n_step_end

  def every_n_step_end(self, step, outputs):
    """Evaluates the latest checkpoint and applies early stopping.

    Evaluation is skipped when the check interval has not elapsed, when no
    checkpoint exists yet, or when the latest checkpoint is the one that was
    evaluated last time.

    Args:
      step: Current step, used for logging and early-stopping bookkeeping.
      outputs: Forwarded unchanged to the parent implementation.

    Returns:
      `True` when early stopping is triggered, `False` otherwise.

    Raises:
      ValueError: If no estimator has been set, or if the early-stopping
        metric is missing from the evaluation outputs.
    """
    super(ValidationMonitor, self).every_n_step_end(step, outputs)
    # TODO(mdan): The use of step below is probably misleading.
    # The code should probably use the step from the checkpoint, because
    # that's what is being evaluated.
    if self._estimator is None:
      raise ValueError("Missing call to set_estimator.")

    # Rate-limit how often the model directory is checked.
    now = time.time()
    interval = self._check_interval_secs
    last_check = self._last_checkpoint_check_time
    if (interval is not None and last_check is not None and
        now - last_check <= interval):
      logging.debug(
          "Skipping evaluation since less than %d seconds have passed since "
          "last check for a new checkpoint.", interval)
      return False
    self._last_checkpoint_check_time = now

    # Only evaluate a checkpoint that has not been evaluated before.
    newest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
    if newest_path is None:
      logging.debug("Skipping evaluation since model has not been saved yet "
                    "at step %d.", step)
      return False
    if newest_path == self._latest_path:
      logging.debug("Skipping evaluation due to same checkpoint %s for step %d "
                    "as for step %d.", newest_path, step,
                    self._latest_path_step)
      return False
    self._latest_path = newest_path
    self._latest_path_step = step

    # Run evaluation and log it.
    results = self._evaluate_estimator()
    summary = ", ".join(
        ["%s = %s" % (name, str(results[name])) for name in results])
    logging.info("Validation (step %d): %s", step, summary)

    # Early stopping logic.
    if self.early_stopping_rounds is None:
      return False
    if self.early_stopping_metric not in results:
      raise ValueError("Metric %s missing from outputs %s." %
                       (self.early_stopping_metric,
                        set(results.keys())))
    current_value = results[self.early_stopping_metric]
    if self._best_value is None:
      improved = True
    elif self.early_stopping_metric_minimize:
      improved = current_value < self._best_value
    else:
      improved = current_value > self._best_value
    if improved:
      self._best_value = current_value
      self._best_metrics = copy.deepcopy(results)
      self._best_value_step = step

    if step - self._best_value_step >= self.early_stopping_rounds:
      logging.info("Stopping. Best step: {} with {} = {}.".format(
          self._best_value_step, self.early_stopping_metric,
          self._best_value))
      self._early_stopped = True
      return True
    return False
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:60,代码来源:monitors.py


示例20: export_fn

  def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None):
    """Exports the given Estimator if the best-model selector picks it.

    Args:
      estimator: the Estimator to export.
      export_dir_base: A string containing a directory to write the exported
        graph and checkpoints.
      checkpoint_path: The checkpoint path to export. If falsy (e.g. `None`),
        the most recent checkpoint found within the model directory is chosen.
      eval_result: Evaluation result handed to the best-model selector;
        present to match the call signature of ExportStrategy.

    Returns:
      The value returned by `best_model_export_strategy.export` (documented
      by the original author as the string path to the exported directory),
      or `''` when the selector yields no checkpoint or no eval result.
    """
    if not checkpoint_path:
      # TODO(b/67425018): switch to
      #    checkpoint_path = estimator.latest_checkpoint()
      #  as soon as contrib is cleaned up and we can thus be sure that
      #  estimator is a tf.estimator.Estimator and not a
      #  tf.contrib.learn.Estimator
      checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
    best_checkpoint, best_eval_result = best_model_selector.update(
        checkpoint_path, eval_result)

    # Nothing to export unless the selector returned both pieces.
    if not best_checkpoint or best_eval_result is None:
      return ''

    # Each export goes into a sub-directory named after its checkpoint.
    export_dir = os.path.join(export_dir_base,
                              os.path.basename(best_checkpoint))
    return best_model_export_strategy.export(
        estimator, export_dir, best_checkpoint, best_eval_result)
开发者ID:Lin-jipeng,项目名称:tensorflow,代码行数:31,代码来源:saved_model_export_utils.py



注:本文中的tensorflow.python.training.saver.latest_checkpoint函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python slot_creator.create_slot函数代码示例发布时间:2022-05-27
下一篇:
Python saver.import_meta_graph函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap