• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python training_util.get_or_create_global_step函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.training.training_util.get_or_create_global_step函数的典型用法代码示例。如果您想了解Python get_or_create_global_step函数的具体用法、调用方式或实际使用示例,这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了get_or_create_global_step函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: run_benchmark

def run_benchmark(sess, init_op, add_op):
  """Run the addition benchmark and log per-step throughput as events.

  Runs `init_op` once. Then, for each of `FLAGS.iters` steps, it runs
  `add_op.op` `FLAGS.iters_per_step` times. After each step it writes a
  scalar 'rate' event to `<FLAGS.logdir_prefix>/<FLAGS.name>/events`.

  Args:
    sess: session used to run the ops.
    init_op: op run once before any timing starts.
    add_op: tensor whose `.op` is run `FLAGS.iters_per_step` times per step.

  Returns:
    None. Results are only written to the events file.
  """
  logdir = FLAGS.logdir_prefix + '/' + FLAGS.name
  # Create the log directory directly rather than via a shell command built
  # from flag values (avoids shell interpretation of the path).
  if not os.path.isdir(logdir):
    os.makedirs(logdir)

  # TODO: make events follow same format as eager writer
  writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(logdir + '/events'))
  try:
    training_util.get_or_create_global_step()
    sess.run(init_op)

    for step in range(FLAGS.iters):
      start_time = time.time()
      for _ in range(FLAGS.iters_per_step):
        sess.run(add_op.op)
      elapsed_time = time.time() - start_time

      # The timed interval covers FLAGS.iters_per_step runs of add_op, so the
      # throughput is based on that count (not on the number of steps).
      # Assumes each add_op run processes FLAGS.data_mb MB — TODO confirm.
      rate = float(FLAGS.iters_per_step) * FLAGS.data_mb / elapsed_time
      event = make_event('rate', rate, step)
      writer.WriteEvent(event)
      writer.Flush()
  finally:
    # Always release the events file, even if a step fails.
    writer.Close()
开发者ID:yaroslavvb,项目名称:stuff,代码行数:25,代码来源:benchmark_grpc_recv.py


示例2: _test_logits_helper

 def _test_logits_helper(self, mode):
   """Tests that the expected logits are passed to mock head.

   Builds the GAN model_fn in a fresh graph for `mode`. The mock head is
   constructed with the expected generator inputs and real data. The helper
   then runs the output relevant to that mode (train_op, loss or predictions)
   in a MonitoredTrainingSession.

   Args:
     mode: a `model_fn_lib.ModeKeys` value; any other value fails the test.
   """
   with ops.Graph().as_default():
     # Create the global step up front; presumably needed by the train op /
     # MonitoredTrainingSession — TODO confirm.
     training_util.get_or_create_global_step()
     generator_inputs = {'x': array_ops.zeros([5, 4])}
     # No labels are supplied in PREDICT mode.
     real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else
                  array_ops.zeros([5, 4]))
     generator_scope_name = 'generator'
     head = mock_head(self,
                      expected_generator_inputs=generator_inputs,
                      expected_real_data=real_data,
                      generator_scope_name=generator_scope_name)
     estimator_spec = estimator._gan_model_fn(
         features=generator_inputs,
         labels=real_data,
         mode=mode,
         generator_fn=generator_fn,
         discriminator_fn=discriminator_fn,
         generator_scope_name=generator_scope_name,
         head=head)
     with monitored_session.MonitoredTrainingSession(
         checkpoint_dir=self._model_dir) as sess:
       # Run only the part of the EstimatorSpec that applies to this mode.
       if mode == model_fn_lib.ModeKeys.TRAIN:
         sess.run(estimator_spec.train_op)
       elif mode == model_fn_lib.ModeKeys.EVAL:
         sess.run(estimator_spec.loss)
       elif mode == model_fn_lib.ModeKeys.PREDICT:
         sess.run(estimator_spec.predictions)
       else:
         self.fail('Invalid mode: {}'.format(mode))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:30,代码来源:gan_estimator_test.py


示例3: testGraphSummary

 def testGraphSummary(self):
   """Checks that summary_ops.initialize(graph=...) records graph nodes.

   Initializes a DB-backed summary writer with a one-node GraphDef. It then
   verifies that exactly that node name is stored in the Nodes table.
   """
   training_util.get_or_create_global_step()
   name = 'hi'
   graph = graph_pb2.GraphDef(node=(node_def_pb2.NodeDef(name=name),))
   with self.test_session():
     with self.create_db_writer().as_default():
       summary_ops.initialize(graph=graph)
   six.assertCountEqual(self, [name],
                        get_all(self.db, 'SELECT node_name FROM Nodes'))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:9,代码来源:summary_ops_graph_test.py


示例4: testEagerMemory

 def testEagerMemory(self):
   """Writes one summary of each kind through a file writer.

   Covers generic, scalar, histogram, image and audio summaries, all
   recorded unconditionally. The test has no explicit assertions; it
   passes if none of the ops raise.
   NOTE(review): the name suggests a memory check applied by a decorator
   outside this view — confirm.
   """
   training_util.get_or_create_global_step()
   logdir = self.get_temp_dir()
   with summary_ops.create_file_writer(
       logdir, max_queue=0,
       name='t0').as_default(), summary_ops.always_record_summaries():
     summary_ops.generic('tensor', 1, '')
     summary_ops.scalar('scalar', 2.0)
     summary_ops.histogram('histogram', [1.0])
     summary_ops.image('image', [[[[1.0]]]])
     summary_ops.audio('audio', [[1.0]], 1.0, 1)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:11,代码来源:summary_ops_test.py


示例5: testSummaryName

  def testSummaryName(self):
    """Checks that a scalar summary is written under the given tag."""
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_file_writer(
        logdir, max_queue=0,
        name='t2').as_default(), summary_ops.always_record_summaries():

      summary_ops.scalar('scalar', 2.0)

      # Events are read back while the writer is still active; max_queue=0
      # presumably makes the write visible immediately. The scalar is
      # expected as the second event (the first is assumed to be a file
      # header event — TODO confirm).
      events = summary_test_util.events_from_logdir(logdir)
      self.assertEqual(len(events), 2)
      self.assertEqual(events[1].summary.value[0].tag, 'scalar')
开发者ID:AnishShah,项目名称:tensorflow,代码行数:12,代码来源:summary_ops_test.py


示例6: testWriteSummaries

  def testWriteSummaries(self):
    """Checks that all_metric_results(logdir) writes a metric summary.

    Feeds 3.0, then [5.0, 7.0, 9.0] through an identity model. The expected
    recorded value is 6.0, which is the mean of those four values (assumes
    SimpleEvaluator tracks a mean — see its definition).
    """
    e = SimpleEvaluator(IdentityModel())
    e(3.0)
    e([5.0, 7.0, 9.0])
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()

    e.all_metric_results(logdir)

    # Expect two events; the metric summary is the second one.
    events = summary_test_util.events_from_file(logdir)
    self.assertEqual(len(events), 2)
    self.assertEqual(events[1].summary.value[0].simple_value, 6.0)
开发者ID:SylChan,项目名称:tensorflow,代码行数:12,代码来源:evaluator_test.py


示例7: testWriteSummaries

  def testWriteSummaries(self):
    """Checks that Mean.result() writes a summary when a writer is active.

    The mean of [1, 10, 100] is 37.0, which is expected as the recorded
    value of the second event in the log directory.
    """
    m = metrics.Mean()
    m([1, 10, 100])
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_file_writer(
        logdir, max_queue=0,
        name="t0").as_default(), summary_ops.always_record_summaries():
      m.result()  # As a side-effect will write summaries.

    events = summary_test_util.events_from_logdir(logdir)
    self.assertEqual(len(events), 2)
    self.assertEqual(events[1].summary.value[0].simple_value, 37.0)
开发者ID:neuroradiology,项目名称:tensorflow,代码行数:13,代码来源:metrics_test.py


示例8: testSummaryOps

 def testSummaryOps(self):
   """Smoke-tests each summary op against a summary file writer.

   Only checks that the calls succeed and that the log directory exists.
   """
   training_util.get_or_create_global_step()
   logdir = tempfile.mkdtemp()
   # Older API: the writer and recording condition are set by plain calls
   # rather than context managers.
   summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0')
   summary_ops.always_record_summaries()
   summary_ops.generic('tensor', 1, '')
   summary_ops.scalar('scalar', 2.0)
   summary_ops.histogram('histogram', [1.0])
   summary_ops.image('image', [[[[1.0]]]])
   summary_ops.audio('audio', [[1.0]], 1.0, 1)
   # The working condition of the ops is tested in the C++ test so we just
   # test here that we're calling them correctly.
   self.assertTrue(gfile.Exists(logdir))
开发者ID:DjangoPeng,项目名称:tensorflow,代码行数:13,代码来源:summary_ops_test.py


示例9: testWriteSummariesGraph

  def testWriteSummariesGraph(self):
    """Graph-mode variant: evaluate on a dataset and check the summary.

    Evaluates an identity model over [3.0, 5.0, 7.0, 9.0] with a summary
    log directory. It expects the second event to hold 6.0, the mean of
    those values.
    """
    with context.graph_mode(), ops.Graph().as_default(), self.test_session():
      e = SimpleEvaluator(IdentityModel())
      ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0])
      training_util.get_or_create_global_step()
      logdir = tempfile.mkdtemp()
      init_op, call_op, results_op = e.evaluate_on_dataset(
          ds, summary_logdir=logdir)
      variables.global_variables_initializer().run()
      e.run_evaluation(init_op, call_op, results_op)

    # Read events only after the graph/session contexts have exited.
    events = summary_test_util.events_from_file(logdir)
    self.assertEqual(len(events), 2)
    self.assertEqual(events[1].summary.value[0].simple_value, 6.0)
开发者ID:SylChan,项目名称:tensorflow,代码行数:14,代码来源:evaluator_test.py


示例10: testSummaryGlobalStep

 def testSummaryGlobalStep(self):
   """Checks that a written summary event is stamped with the global step.

   The global step value is fetched in the same run as the summary ops. It
   is then compared to the `step` field of the written event.
   """
   training_util.get_or_create_global_step()
   logdir = self.get_temp_dir()
   writer = summary_ops.create_file_writer(logdir, max_queue=0)
   with writer.as_default(), summary_ops.always_record_summaries():
     summary_ops.scalar('scalar', 2.0)
   with self.cached_session() as sess:
     sess.run(variables.global_variables_initializer())
     sess.run(summary_ops.summary_writer_initializer_op())
     step, _ = sess.run(
         [training_util.get_global_step(), summary_ops.all_summary_ops()])
   events = summary_test_util.events_from_logdir(logdir)
   self.assertEqual(2, len(events))
   self.assertEqual(step, events[1].step)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:14,代码来源:summary_ops_graph_test.py


示例11: testDefunSummarys

  def testDefunSummarys(self):
    """Checks that a scalar summary written inside a defun is recorded."""
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    with summary_ops.create_summary_file_writer(
        logdir, max_queue=0,
        name='t1').as_default(), summary_ops.always_record_summaries():

      # The summary op is created inside a function compiled by defun.
      @function.defun
      def write():
        summary_ops.scalar('scalar', 2.0)

      write()
      events = summary_test_util.events_from_logdir(logdir)
      self.assertEqual(len(events), 2)
      self.assertEqual(events[1].summary.value[0].simple_value, 2.0)
开发者ID:abidrahmank,项目名称:tensorflow,代码行数:15,代码来源:summary_ops_test.py


示例12: setUp

 def setUp(self):
   """Creates a temp model dir and a graph containing a global step.

   Also builds a Scaffold and a train op that increments the global step
   by 1 each time it runs.
   """
   self.model_dir = tempfile.mkdtemp()
   self.graph = ops.Graph()
   with self.graph.as_default():
     self.scaffold = monitored_session.Scaffold()
     self.global_step = training_util.get_or_create_global_step()
     self.train_op = state_ops.assign_add(self.global_step, 1)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:7,代码来源:monitors_test.py


示例13: testAgnosticUsage

 def testAgnosticUsage(self):
   """Graph/eager agnostic usage."""
   # NOTE(review): an earlier comment here said this creates garbage in eager
   # mode "due to ops.Graph() creation", but this variant creates no Graph —
   # that remark looks stale; confirm before relying on it.
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   # Three rounds: each restores the latest checkpoint (if any), trains,
   # saves, and checks that step and save counters kept accumulating.
   for training_continuation in range(3):
     with test_util.device(use_gpu=True):
       model = MyModel()
       optimizer = adam.AdamOptimizer(0.001)
       root = checkpointable_utils.Checkpoint(
           optimizer=optimizer, model=model,
           global_step=training_util.get_or_create_global_step())
       manager = checkpoint_management.CheckpointManager(
           root, checkpoint_directory, max_to_keep=1)
       status = root.restore(save_path=manager.latest_checkpoint)
       input_value = constant_op.constant([[3.]])
       train_fn = functools.partial(
           optimizer.minimize,
           functools.partial(model, input_value),
           global_step=root.global_step)
       # In graph mode, build the train op once and re-run it on each call.
       if not context.executing_eagerly():
         train_fn = functools.partial(self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       manager.save()
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:30,代码来源:util_with_v1_optimizers_test.py


示例14: testAgnosticUsage

 def testAgnosticUsage(self):
   """Graph/eager agnostic usage."""
   # Does create garbage when executing eagerly due to ops.Graph() creation.
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   # Three rounds, each in a fresh graph: restore the latest checkpoint (if
   # any), train, save, and check that the counters kept accumulating.
   for training_continuation in range(3):
     with ops.Graph().as_default(), self.test_session(
         graph=ops.get_default_graph()):
       network = MyNetwork()
       optimizer = CheckpointableAdam(0.001)
       root = Checkpoint(
           optimizer=optimizer, network=network,
           global_step=training_util.get_or_create_global_step())
       checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       input_value = constant_op.constant([[3.]])
       train_fn = functools.partial(
           optimizer.minimize,
           functools.partial(network, input_value),
           global_step=root.global_step)
       # In graph mode, build the train op once and re-run it on each call.
       # NOTE(review): sibling variants use `not context.executing_eagerly()`;
       # `in_graph_mode()` is a legacy API — confirm it exists in the target
       # TF version.
       if context.in_graph_mode():
         train_fn = functools.partial(self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       root.save(file_prefix=checkpoint_prefix)
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
开发者ID:hhu-luqi,项目名称:tensorflow,代码行数:31,代码来源:checkpointable_utils_test.py


示例15: _clone_and_build_model

def _clone_and_build_model(mode,
                           keras_model,
                           custom_objects,
                           features=None,
                           labels=None):
  """Clone and build the given keras_model.

  Args:
    mode: training mode, a `model_fn_lib.ModeKeys` value.
    keras_model: an instance of compiled keras model.
    custom_objects: Dictionary for custom objects (may be None or empty).
    features: optional input features. When given, they are passed through
      `_create_ordered_io` and used as the clone's input tensors.
    labels: optional labels. When given, they become the compiled clone's
      target tensors. A dict goes through `_create_ordered_io(...,
      is_input=False)`. Anything else is converted to a (sparse) tensor and
      cast to floatx.

  Returns:
    The newly built model. If the clone is a `models.Sequential`, its inner
    `.model` attribute is returned instead.
  """
  # Set to True during training, False for inference.
  K.set_learning_phase(mode == model_fn_lib.ModeKeys.TRAIN)

  # Clone keras model.
  input_tensors = None if features is None else _create_ordered_io(
      keras_model, features)
  if custom_objects:
    with CustomObjectScope(custom_objects):
      model = models.clone_model(keras_model, input_tensors=input_tensors)
  else:
    model = models.clone_model(keras_model, input_tensors=input_tensors)

  # Compile/Build model. Compare mode keys by value (`==`, as above), not by
  # identity: `is` on equal-valued keys is not guaranteed to be True.
  if mode == model_fn_lib.ModeKeys.PREDICT and not model.built:
    model.build()
  else:
    # Give the clone a fresh optimizer built from the original's config,
    # with its iteration counter tied to the global step.
    optimizer_config = keras_model.optimizer.get_config()
    optimizer = keras_model.optimizer.__class__.from_config(optimizer_config)
    optimizer.iterations = training_util.get_or_create_global_step()

    # Get list of outputs.
    if labels is None:
      target_tensors = None
    elif isinstance(labels, dict):
      target_tensors = _create_ordered_io(keras_model, labels, is_input=False)
    else:
      target_tensors = [
          _cast_tensor_to_floatx(
              sparse_tensor_lib.convert_to_tensor_or_sparse_tensor(labels))
      ]

    model.compile(
        optimizer,
        keras_model.loss,
        metrics=keras_model.metrics,
        loss_weights=keras_model.loss_weights,
        sample_weight_mode=keras_model.sample_weight_mode,
        weighted_metrics=keras_model.weighted_metrics,
        target_tensors=target_tensors)

  if isinstance(model, models.Sequential):
    model = model.model
  return model
开发者ID:keithc61,项目名称:tensorflow,代码行数:60,代码来源:estimator.py


示例16: testSummaryName

  def testSummaryName(self):
    """Checks the on-disk record of a scalar summary carries its tag.

    Reads the single file in the log directory as TFRecords and expects two
    records. The second record parses as an Event whose first summary value
    is tagged 'scalar'.
    """
    training_util.get_or_create_global_step()
    logdir = tempfile.mkdtemp()
    summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t2')
    summary_ops.always_record_summaries()

    summary_ops.scalar('scalar', 2.0)

    self.assertTrue(gfile.Exists(logdir))
    files = gfile.ListDirectory(logdir)
    self.assertEqual(len(files), 1)
    records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
    self.assertEqual(len(records), 2)
    event = event_pb2.Event()
    event.ParseFromString(records[1])
    self.assertEqual(event.summary.value[0].tag, 'scalar')
开发者ID:DjangoPeng,项目名称:tensorflow,代码行数:16,代码来源:summary_ops_test.py


示例17: testGraphDistributionStrategy

  def testGraphDistributionStrategy(self):
    """Checkpoint save/restore round trips under a MirroredStrategy (graph).

    Currently skipped unconditionally, so nothing after skipTest runs.
    """
    self.skipTest("b/121381184")
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

    # `root` is looked up when _train_fn is called, i.e. after it is bound in
    # the loop below.
    def _train_fn(optimizer, model):
      input_value = constant_op.constant([[3.]])
      return optimizer.minimize(
          functools.partial(model, input_value),
          global_step=root.optimizer_step)

    for training_continuation in range(3):
      with ops.Graph().as_default():
        strategy = mirrored_strategy.MirroredStrategy()
        with strategy.scope():
          model = MyModel()
          optimizer = adam.AdamOptimizer(0.001)
          root = checkpointable_utils.Checkpoint(
              optimizer=optimizer, model=model,
              optimizer_step=training_util.get_or_create_global_step())
          status = root.restore(checkpoint_management.latest_checkpoint(
              checkpoint_directory))
          train_op = strategy.extended.call_for_each_replica(
              functools.partial(_train_fn, optimizer, model))
          with self.session() as session:
            # After the first round a checkpoint exists and must fully match.
            if training_continuation > 0:
              status.assert_consumed()
            status.initialize_or_restore()
            for _ in range(num_training_steps):
              session.run(train_op)
            root.save(file_prefix=checkpoint_prefix)
        # NOTE(review): this runs in graph mode outside the session, where
        # `.numpy()` on a graph variable is presumably unavailable. If the
        # skip is removed, this likely needs evaluating inside the session.
        self.assertEqual((training_continuation + 1) * num_training_steps,
                         root.optimizer_step.numpy())
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:34,代码来源:util_with_v1_optimizers_test.py


示例18: testAgnosticUsage

 def testAgnosticUsage(self):
   """Graph/eager agnostic usage."""
   # Does create garbage when executing eagerly due to ops.Graph() creation.
   num_training_steps = 10
   checkpoint_directory = self.get_temp_dir()
   checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
   # Three rounds, each in a fresh graph: restore the latest checkpoint (if
   # any), train, save, and check that the counters kept accumulating.
   for training_continuation in range(3):
     with ops.Graph().as_default(), self.test_session(
         graph=ops.get_default_graph()), test_util.device(use_gpu=True):
       model = MyModel()
       optimizer = adam.AdamOptimizer(0.001)
       root = util.Checkpoint(
           optimizer=optimizer, model=model,
           global_step=training_util.get_or_create_global_step())
       checkpoint_path = checkpoint_management.latest_checkpoint(
           checkpoint_directory)
       status = root.restore(save_path=checkpoint_path)
       input_value = constant_op.constant([[3.]])
       train_fn = functools.partial(
           optimizer.minimize,
           functools.partial(model, input_value),
           global_step=root.global_step)
       # In graph mode, build the train op once and re-run it on each call.
       if not context.executing_eagerly():
         train_fn = functools.partial(self.evaluate, train_fn())
       status.initialize_or_restore()
       for _ in range(num_training_steps):
         train_fn()
       root.save(file_prefix=checkpoint_prefix)
       self.assertEqual((training_continuation + 1) * num_training_steps,
                        self.evaluate(root.global_step))
       self.assertEqual(training_continuation + 1,
                        self.evaluate(root.save_counter))
开发者ID:jackd,项目名称:tensorflow,代码行数:32,代码来源:checkpointable_utils_test.py


示例19: test_inv_update_thunks

  def test_inv_update_thunks(self):
    """Ensures inverse update ops run once per global_step.

    Builds a `control_flow_ops.case` that runs inverse-update thunk `i` when
    the global step equals `i`. It sets every covariance matrix to 2*I. It
    then checks that after each (update, increment) pair exactly one more
    inverse matrix differs from its initial value.
    """
    with self._graph.as_default(), self.test_session() as sess:
      fisher_estimator = estimator.FisherEstimator(
          damping_fn=lambda: 0.2,
          variables=[self.weights],
          layer_collection=self.layer_collection,
          cov_ema_decay=0.0)

      # Construct op that updates one inverse per global step.
      global_step = training_util.get_or_create_global_step()
      inv_matrices = [
          matrix
          for fisher_factor in self.layer_collection.get_factors()
          for matrix in fisher_factor._inverses_by_damping.values()
      ]
      inv_update_op_thunks = fisher_estimator.inv_update_thunks
      inv_update_op = control_flow_ops.case(
          [(math_ops.equal(global_step, i), thunk)
           for i, thunk in enumerate(inv_update_op_thunks)])
      increment_global_step = global_step.assign_add(1)

      sess.run(variables.global_variables_initializer())
      initial_inv_values = sess.run(inv_matrices)

      # Ensure there's one update per inverse matrix. This is true as long as
      # there's no fan-in/fan-out or parameter re-use.
      self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))

      # The test is vacuous with only one inverse matrix. Use a unittest
      # assertion instead of a bare `assert`, which is stripped under -O.
      self.assertGreater(len(inv_matrices), 1)

      # Assign each covariance matrix a value other than the identity. This
      # ensures that the inverse matrices are updated to something different as
      # well.
      cov_matrices = [
          fisher_factor.get_cov()
          for fisher_factor in self.layer_collection.get_factors()
      ]
      sess.run([
          cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
          for cov_matrix in cov_matrices
      ])

      for i in range(len(inv_matrices)):
        # Compare new and old inverse values
        new_inv_values = sess.run(inv_matrices)
        is_inv_equal = [
            np.allclose(initial_inv_value, new_inv_value)
            for (initial_inv_value,
                 new_inv_value) in zip(initial_inv_values, new_inv_values)
        ]
        num_inv_equal = sum(is_inv_equal)

        # After i steps, exactly i inverse matrices have changed.
        self.assertEqual(num_inv_equal, len(inv_matrices) - i)

        # Run the inverse update selected by the current global step, then
        # advance the step.
        sess.run(inv_update_op)
        sess.run(increment_global_step)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:60,代码来源:estimator_test.py


示例20: testUsageGraph

 def testUsageGraph(self):
   """Expected usage when graph building."""
   with context.graph_mode():
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     # Three rounds, each in a fresh graph: restore, train, save, and check
     # that the global step and save counter kept accumulating.
     for training_continuation in range(3):
       with ops.Graph().as_default():
         model = MyModel()
         optimizer = adam.AdamOptimizer(0.001)
         root = util.Checkpoint(
             optimizer=optimizer, model=model,
             global_step=training_util.get_or_create_global_step())
         input_value = constant_op.constant([[3.]])
         train_op = optimizer.minimize(
             model(input_value),
             global_step=root.global_step)
         checkpoint_path = checkpoint_management.latest_checkpoint(
             checkpoint_directory)
         with self.session(graph=ops.get_default_graph()) as session:
           status = root.restore(save_path=checkpoint_path)
           status.initialize_or_restore(session=session)
           # With no checkpoint (first round only), assert_consumed must fail;
           # otherwise every saved value must have been restored.
           if checkpoint_path is None:
             self.assertEqual(0, training_continuation)
             with self.assertRaises(AssertionError):
               status.assert_consumed()
           else:
             status.assert_consumed()
           for _ in range(num_training_steps):
             session.run(train_op)
           root.save(file_prefix=checkpoint_prefix, session=session)
           self.assertEqual((training_continuation + 1) * num_training_steps,
                            session.run(root.global_step))
           self.assertEqual(training_continuation + 1,
                            session.run(root.save_counter))
开发者ID:jackd,项目名称:tensorflow,代码行数:35,代码来源:checkpointable_utils_test.py



注:本文中的tensorflow.python.training.training_util.get_or_create_global_step函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python training_util.global_step函数代码示例 发布时间:2022-05-27
下一篇:
Python training_util.get_global_step函数代码示例 发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap