• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python checkpoint_utils.load_variable函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了 Python 中 tensorflow.contrib.framework.python.framework.checkpoint_utils.load_variable 函数的典型用法代码示例。如果您正苦于以下问题:Python load_variable 函数的具体用法?Python load_variable 怎么用?Python load_variable 使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了 load_variable 函数的 15 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。

示例1: test_save_steps_saves_periodically

 def test_save_steps_saves_periodically(self):
   """Checks the checkpoint cadence of `CheckpointSaver` with `save_steps=2`.

   Runs training steps 1 through 5 and, from step 2 onwards, compares the
   global step read back from `self.model_dir` with the step at which the
   most recent save is expected (1, 3, 3 and 5 respectively).
   """

   def checkpointed_step():
     # Value of the global step variable as stored under `self.model_dir`.
     return checkpoint_utils.load_variable(self.model_dir,
                                           self.global_step.name)

   # (step to run, global step expected in the checkpoint afterwards).
   schedule = (
       (2, 1),  # Not saved
       (3, 3),  # saved
       (4, 3),  # Not saved
       (5, 5),  # saved
   )
   with self.graph.as_default():
     saver_monitor = learn.monitors.CheckpointSaver(
         self.model_dir, save_steps=2, scaffold=self.scaffold)
     saver_monitor.begin()
     self.scaffold.finalize()
     with session_lib.Session() as session:
       session.run(self.scaffold.init_op)
       # Step 1 is run without a check; checks start after step 2.
       self._run(saver_monitor, 1, self.train_op, session)
       for step, expected_saved_step in schedule:
         self._run(saver_monitor, step, self.train_op, session)
         self.assertEqual(expected_saved_step, checkpointed_step())
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:29,代码来源:monitors_test.py


示例2: test_save_steps_saves_periodically

 def test_save_steps_saves_periodically(self):
   """Checks the checkpoint cadence of `CheckpointSaverHook` (`save_steps=2`).

   Runs `self.train_op` five times through a hooked session. After each of
   the last four runs, the global step read back from `self.model_dir` must
   match the step of the most recent expected save: 1, 3, 3, then 5.
   """
   # Expected checkpointed global step after runs 2, 3, 4 and 5.
   expected_saved_steps = (
       1,  # Not saved
       3,  # saved
       3,  # Not saved
       5,  # saved
   )
   with self.graph.as_default():
     saver_hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir, save_steps=2, scaffold=self.scaffold)
     saver_hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as session:
       session.run(self.scaffold.init_op)
       hooked = monitored_session._HookedSession(session, [saver_hook])
       # The first run is not followed by a check.
       hooked.run(self.train_op)
       for expected in expected_saved_steps:
         hooked.run(self.train_op)
         saved_step = checkpoint_utils.load_variable(self.model_dir,
                                                     self.global_step.name)
         self.assertEqual(expected, saved_step)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:30,代码来源:basic_session_run_hooks_test.py


示例3: test_train_max_steps_is_not_incremental

  def test_train_max_steps_is_not_incremental(self):
    """Checks that the saved global step equals `max_steps` after each call.

    Trains twice into the same `self._output_dir`, first with `max_steps=10`
    and then with `max_steps=15`, and expects the checkpointed global step to
    equal the `max_steps` value of the call that just finished.
    """
    for step_limit in (10, 15):
      # Every round builds a fresh graph but reuses `self._output_dir`.
      with ops.Graph().as_default() as graph, self.test_session(graph):
        with ops.control_dependencies(self._build_inference_graph()):
          increment_step = state_ops.assign_add(
              variables_lib.get_global_step(), 1)
        learn.graph_actions.train(
            graph,
            output_dir=self._output_dir,
            train_op=increment_step,
            loss_op=constant_op.constant(2.0),
            max_steps=step_limit)
        saved_step = checkpoint_utils.load_variable(
            self._output_dir, variables_lib.get_global_step().name)
        self.assertEqual(step_limit, saved_step)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:26,代码来源:graph_actions_test.py


示例4: test_train_skip_train_if_max_step_already_saved

  def test_train_skip_train_if_max_step_already_saved(self):
    """Checks that re-training with the same `max_steps` leaves the step at 10.

    `_monitored_train` is called twice with `max_steps=10` on the same
    `self._output_dir`; after each call the checkpointed global step must
    be 10.
    """

    def train_to_ten_and_check():
      # One full round: fresh graph, train, read the saved step, assert.
      with ops.Graph().as_default() as graph, self.test_session(graph):
        with ops.control_dependencies(self._build_inference_graph()):
          increment_step = state_ops.assign_add(
              variables_lib.get_global_step(), 1)
        learn.graph_actions._monitored_train(  # pylint: disable=protected-access
            graph,
            output_dir=self._output_dir,
            train_op=increment_step,
            loss_op=constant_op.constant(2.0),
            max_steps=10)
        saved_step = checkpoint_utils.load_variable(
            self._output_dir, variables_lib.get_global_step().name)
        self.assertEqual(10, saved_step)

    train_to_ten_and_check()
    # The second round starts from the directory the first round wrote to.
    train_to_ten_and_check()
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:26,代码来源:graph_actions_test.py


示例5: testGetTensor

 def testGetTensor(self):
   """Checks `load_variable` against the values from `_create_checkpoints`.

   Each of the four variable names is loaded from the checkpoint directory
   and compared with the corresponding value returned by the helper.
   """
   ckpt_dir = self.get_temp_dir()
   with self.test_session() as sess:
     v1, v2, v3, v4 = _create_checkpoints(sess, ckpt_dir)
   # Checkpoint variable name paired with the value expected for it.
   expected_by_name = (
       ("var1", v1),
       ("var2", v2),
       ("var3", v3),
       ("useful_scope/var4", v4),
   )
   for var_name, expected in expected_by_name:
     loaded = checkpoint_utils.load_variable(ckpt_dir, var_name)
     self.assertAllEqual(loaded, expected)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:12,代码来源:checkpoint_utils_test.py


示例6: print_tensors_in_checkpoint_file

def print_tensors_in_checkpoint_file(file_name, tensor_name=None):
  """Prints tensors in a checkpoint file.

  If no `tensor_name` is provided (omitted, `None` or empty), prints the
  tensor names and shapes in the checkpoint file.

  If `tensor_name` is provided, prints the content of the tensor.

  Failures are not propagated: any exception raised while reading is printed,
  followed by a hint when the message suggests SNAPPY compression.

  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Optional name of the tensor in the checkpoint file to print.
      Defaults to `None`, which lists every tensor instead.
  """
  try:
    if not tensor_name:
      # Listing mode: one "name<TAB>shape" line per variable.
      for name, shape in checkpoint_utils.list_variables(file_name):
        print("%s\t%s" % (name, str(shape)))
    else:
      print("tensor_name: ", tensor_name)
      print(checkpoint_utils.load_variable(file_name, tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    # Deliberately broad: this is a best-effort inspection helper, so the
    # error text is reported instead of raising.
    message = str(e)
    print(message)
    if "corrupted compressed block contents" in message:
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
开发者ID:kadeng,项目名称:tensorflow,代码行数:25,代码来源:inspect_checkpoint.py


示例7: testNoTensor

 def testNoTensor(self):
   """Checks that loading the name "var5" raises `errors_impl.OpError`."""
   ckpt_dir = self.get_temp_dir()
   with self.test_session() as sess:
     # Only the side effect matters here; the four returned values are unused.
     _, _, _, _ = _create_checkpoints(sess, ckpt_dir)
   with self.assertRaises(errors_impl.OpError):
     missing = checkpoint_utils.load_variable(ckpt_dir, "var5")
     self.assertAllEqual(missing, [])
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:7,代码来源:checkpoint_utils_test.py


示例8: test_save_secs_saves_periodically

  def test_save_secs_saves_periodically(self, mock_time):
    """Checks the time-based cadence of `CheckpointSaverHook` (`save_secs=2`).

    The return value of `mock_time` (presumably a patched clock; confirm
    against the decorator on this test) is set by hand before every training
    step. The global step read back from `self.model_dir` is then compared
    with the step of the last expected save.
    """
    # Let's have a realistic start time
    start_time = 1484695987.209386

    with self.graph.as_default():
      mock_time.return_value = start_time
      saver_hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir, save_secs=2, scaffold=self.scaffold)
      saver_hook.begin()
      self.scaffold.finalize()

      with session_lib.Session() as session:
        session.run(self.scaffold.init_op)
        hooked = monitored_session._HookedSession(session, [saver_hook])

        def train_step_at(elapsed_secs):
          # Set the mocked time to `start_time + elapsed_secs`, then run a step.
          mock_time.return_value = start_time + elapsed_secs
          hooked.run(self.train_op)

        def checkpointed_step():
          return checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name)

        train_step_at(0)  # Saved.
        train_step_at(0.5)  # Not saved.
        self.assertEqual(1, checkpointed_step())

        # Simulate 2.5 seconds of sleep.
        train_step_at(2.5)  # Saved.
        train_step_at(2.6)  # Not saved.
        train_step_at(2.7)  # Not saved.
        self.assertEqual(3, checkpointed_step())

        # Simulate 7.5 more seconds of sleep (10 seconds from start).
        train_step_at(10)  # Saved.
        self.assertEqual(6, checkpointed_step())
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:45,代码来源:basic_session_run_hooks_test.py


示例9: test_saves_when_saver_and_scaffold_both_missing

 def test_saves_when_saver_and_scaffold_both_missing(self):
   """Checks that a hook built without `saver` or `scaffold` still saves.

   The hook only receives `save_steps=1`; after one training step the global
   step read back from `self.model_dir` must be 1.
   """
   with self.graph.as_default():
     saver_hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir, save_steps=1)
     saver_hook.begin()
     # The fixture scaffold is still finalized; it is just not given to the hook.
     self.scaffold.finalize()
     with session_lib.Session() as session:
       session.run(self.scaffold.init_op)
       hooked = monitored_session._HookedSession(session, [saver_hook])
       hooked.run(self.train_op)
       saved_step = checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name)
       self.assertEqual(1, saved_step)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:13,代码来源:basic_session_run_hooks_test.py


示例10: test_save_saves_at_end

 def test_save_saves_at_end(self):
   """Checks that `CheckpointSaver.end` leaves global step 2 in the checkpoint.

   Two training steps are run with a `save_secs=2` monitor, then `end` is
   called and the global step read back from `self.model_dir` must be 2.
   """
   with self.graph.as_default():
     saver_monitor = learn.monitors.CheckpointSaver(
         self.model_dir, save_secs=2, scaffold=self.scaffold)
     saver_monitor.begin()
     self.scaffold.finalize()
     with session_lib.Session() as session:
       session.run(self.scaffold.init_op)
       for step in (1, 2):
         self._run(saver_monitor, step, self.train_op, session)
       saver_monitor.end(session)
       saved_step = checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name)
       self.assertEqual(2, saved_step)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:14,代码来源:monitors_test.py


示例11: covariances

 def covariances(self):
   """Returns the cluster covariances loaded from `self.model_dir`."""
   variable_name = gmm_ops.GmmAlgorithm.CLUSTERS_COVS_VARIABLE
   return checkpoint_utils.load_variable(self.model_dir, variable_name)
开发者ID:ivankreso,项目名称:tensorflow,代码行数:4,代码来源:gmm.py


示例12: clusters

 def clusters(self):
   """Returns the cluster centers loaded from `self.model_dir`.

   The loaded array has its axis 1 squeezed out before it is returned.
   """
   variable_name = gmm_ops.GmmAlgorithm.CLUSTERS_VARIABLE
   centers = checkpoint_utils.load_variable(self.model_dir, variable_name)
   return np.squeeze(centers, axis=1)
开发者ID:ivankreso,项目名称:tensorflow,代码行数:5,代码来源:gmm.py


示例13: weights

 def weights(self):
   """Returns the cluster weights loaded from `self.model_dir`."""
   variable_name = gmm_ops.GmmAlgorithm.CLUSTERS_WEIGHT
   return checkpoint_utils.load_variable(self.model_dir, variable_name)
开发者ID:Immexxx,项目名称:tensorflow,代码行数:4,代码来源:gmm.py


示例14: load_variable

def load_variable(checkpoint_dir, name):
  """Forwards to `checkpoint_utils.load_variable` and returns its result.

  See `tf.contrib.framework.load_variable` for the full contract.

  Args:
    checkpoint_dir: Passed through unchanged as the first argument.
    name: Passed through unchanged as the second argument.
  """
  value = checkpoint_utils.load_variable(checkpoint_dir, name)
  return value
开发者ID:2020zyc,项目名称:tensorflow,代码行数:3,代码来源:checkpoints.py


示例15: testNoCheckpoints

 def testNoCheckpoints(self):
   """Checks that loading from the "no_checkpoints" subdirectory raises `OpError`."""
   # No code in this test creates or writes to this subdirectory.
   empty_dir = self.get_temp_dir() + "/no_checkpoints"
   with self.assertRaises(errors_impl.OpError):
     loaded = checkpoint_utils.load_variable(empty_dir, "var1")
     self.assertAllEqual(loaded, [])
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:5,代码来源:checkpoint_utils_test.py



注:本文中的tensorflow.contrib.framework.python.framework.checkpoint_utils.load_variable函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python ops.arg_scope函数代码示例发布时间:2022-05-27
下一篇:
Python framework.load_variable函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap