• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python network.save_network_checkpoint函数代码示例

原作者: 未注明 来自: 未注明 收藏 邀请

本文汇总了 Python 中 tensorflow.contrib.eager.python.network.save_network_checkpoint 函数的典型用法示例。如果您想了解 save_network_checkpoint 函数的具体用法、调用方式或使用示例，下面精选的代码示例或许能为您提供帮助。注意：tf.contrib 模块已在 TensorFlow 2.0 中移除，以下示例仅适用于 TensorFlow 1.x。



下文共展示了 save_network_checkpoint 函数的 9 个代码示例，默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将帮助我们的系统推荐更好的 Python 代码示例。

示例1: testNetworkSaveRestoreAlreadyBuilt

 def testNetworkSaveRestoreAlreadyBuilt(self):
   """Saving fails before the first call, then round-trips once built."""
   model = MyNetwork(name="abcd")
   # Saving a Network that has never been called is expected to raise.
   with self.assertRaisesRegexp(
       ValueError, "Attempt to save the Network before it was first called"):
     network.save_network_checkpoint(model, self.get_temp_dir())
   model(constant_op.constant([[2.0]]))
   first_var = model.trainable_variables[0]
   self.evaluate(first_var.assign([[17.0]]))
   # Run the save/modify/restore cycle without and then with a global step.
   for step in (None, 10):
     self._save_modify_load_network_built(model, global_step=step)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:9,代码来源:network_test.py


示例2: testRestoreIntoSubNetwork

  def testRestoreIntoSubNetwork(self):
    """Checks restoring checkpoints into a Network and into its sub-Networks.

    A whole-model checkpoint (saved from a Parent) and a sub-Network
    checkpoint (saved from a standalone MyNetwork) are both restored into a
    fresh Parent, in both orders. Deferred restores (issued before the Parent
    is first called) must behave like immediate ones: the most recent restore
    wins. Restoring the sub-Network checkpoint into a whole Parent must fail.
    """

    class Parent(network.Network):
      # Test-only Network tracking two MyNetwork sub-Networks.

      def __init__(self, name=None):
        super(Parent, self).__init__(name=name)
        self.first = self.track_layer(MyNetwork())
        self.second = self.track_layer(MyNetwork())

      def call(self, x):
        return self.first(self.second(x))

    # Build a Parent, give its two variables distinctive values and save it.
    one = constant_op.constant([[3.]])
    whole_model_saver = Parent()
    whole_model_saver(one)
    self.evaluate(whole_model_saver.variables[0].assign([[15.]]))
    self.evaluate(whole_model_saver.variables[1].assign([[16.]]))
    whole_model_checkpoint = network.save_network_checkpoint(
        whole_model_saver, self.get_temp_dir())

    # Save a standalone MyNetwork whose single checked value is 5.
    save_from = MyNetwork()
    save_from(one)
    self.evaluate(save_from.variables[0].assign([[5.]]))
    checkpoint = network.save_network_checkpoint(save_from, self.get_temp_dir())
    # First pass: whole model first, then the sub-Network into `first`. The
    # Parent has not been called yet, so these restores are all deferred.
    save_into_parent = Parent()
    network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
    network.restore_network_checkpoint(save_into_parent.first, checkpoint)
    # deferred loading multiple times is fine
    network.restore_network_checkpoint(save_into_parent.first, checkpoint)
    save_into_parent(one)  # deferred loading
    # variables[0] got the later sub-Network restore (5); variables[1] kept
    # the whole-model value (16).
    self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0]))
    self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))

    # Try again with the opposite ordering, and we should get different results
    # (deferred restoration should happen the same way non-deferred happens,
    # with later restorations overwriting older ones).
    save_into_parent = Parent()
    # Second pass: sub-Network first, then the whole model (both deferred).
    network.restore_network_checkpoint(save_into_parent.first, checkpoint)
    network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
    save_into_parent(one)  # deferred loading
    # We've overwritten the sub-Network restore.
    self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0]))
    self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))

    # The Parent is built now, so this restore into `second` is immediate.
    self.evaluate(save_into_parent.variables[0].assign([[3.]]))
    self.evaluate(save_into_parent.variables[1].assign([[4.]]))
    network.restore_network_checkpoint(save_into_parent.second, checkpoint)
    self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1]))
    with self.assertRaisesRegexp(errors_impl.NotFoundError,
                                 "not found in checkpoint"):
      # The checkpoint is incompatible.
      network.restore_network_checkpoint(save_into_parent, checkpoint)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:53,代码来源:network_test.py


示例3: testDefaultMapCollisionErrors

  def testDefaultMapCollisionErrors(self):
    """Checks that the default name mapping reports naming conflicts.

    Parent tracks a Dense layer named "dense" that was built outside of it,
    plus a second Dense layer it creates itself. Saving or restoring a Parent
    with the default variable name mapping must raise a ValueError about a
    naming conflict. A compatible single-layer Network saves cleanly.
    """

    one = constant_op.constant([[1.]])
    # NOTE(review): this layer is built outside any Network. Presumably that
    # is what makes its default-mapped name collide with the Dense layer that
    # Parent creates itself; confirm against the default mapping code.
    first = core.Dense(1, name="dense", use_bias=False)
    first(one)

    class Parent(network.Network):

      def __init__(self, name=None):
        super(Parent, self).__init__(name=name)
        self.first = self.track_layer(first)
        self.second = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        return self.first(self.second(x))

    make_checkpoint = Parent()
    # The `one` input defined above is reused as the input here.
    make_checkpoint(one)
    self.evaluate(make_checkpoint.variables[0].assign([[2.]]))
    self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
    with self.assertRaisesRegexp(
        ValueError,
        ("The default checkpoint variable name mapping strategy for Network "
         "'parent' resulted in a naming conflict.")):
      network.save_network_checkpoint(make_checkpoint, self.get_temp_dir())

    class Compatible(network.Network):
      # Single-layer Network; saving it with the default mapping succeeds.

      def __init__(self, name=None):
        super(Compatible, self).__init__(name=name)
        self.first = self.track_layer(core.Dense(1, use_bias=False))

      def call(self, x):
        return self.first(x)

    successful_checkpoint = Compatible()
    successful_checkpoint(one)
    self.evaluate(successful_checkpoint.variables[0].assign([[-1.]]))
    checkpoint_path = network.save_network_checkpoint(
        successful_checkpoint, self.get_temp_dir())
    # The second Parent instance is auto-named 'parent_1'. Restoring into it
    # hits the same conflict that saving did.
    load_checkpoint = Parent()
    load_checkpoint(one)
    with self.assertRaisesRegexp(
        ValueError,
        ("The default checkpoint variable name mapping strategy for Network "
         "'parent_1' resulted in a naming conflict.")):
      network.restore_network_checkpoint(load_checkpoint, checkpoint_path)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:48,代码来源:network_test.py


示例4: testSaveRestoreDefaultGlobalStep

 def testSaveRestoreDefaultGlobalStep(self):
   """The default global step's value appears in the checkpoint path."""
   model = MyNetwork(name="abcd")
   model(constant_op.constant([[2.0]]))
   self.evaluate(model.variables[0].assign([[3.]]))
   step_var = training_util.get_or_create_global_step()
   self.evaluate(step_var.assign(4242))
   # No explicit global_step is passed; the default one is expected to be used.
   checkpoint_path = network.save_network_checkpoint(
       model, self.get_temp_dir())
   self.assertIn("abcd-4242", checkpoint_path)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:network_test.py


示例5: testCustomMapCollisionErrors

  def testCustomMapCollisionErrors(self):
    """Checks that a map_func producing duplicate names is rejected.

    Saving a two-variable Network with a map_func that sends every name to
    'foo' raises immediately. Restoring with such a map_func raises when the
    deferred restore is applied (on the first call), or immediately when the
    Network has already been called.
    """

    class Parent(network.Network):
      # Test-only Network tracking two MyNetwork sub-Networks.

      def __init__(self, name=None):
        super(Parent, self).__init__(name=name)
        self.first = self.track_layer(MyNetwork())
        self.second = self.track_layer(MyNetwork())

      def call(self, x):
        return self.first(self.second(x))

    make_checkpoint = Parent()
    one = constant_op.constant([[1.]])
    make_checkpoint(one)
    self.evaluate(make_checkpoint.variables[0].assign([[2.]]))
    self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
    with self.assertRaisesRegexp(
        ValueError,
        "The map_func passed to save_network_checkpoint for the Network "
        "'parent' resulted in two variables named 'foo'"):
      network.save_network_checkpoint(
          make_checkpoint, self.get_temp_dir(), map_func=lambda n: "foo")
    # Saving only the `first` sub-Network with the same map_func is not
    # expected to raise (presumably because it maps just one variable).
    checkpoint = network.save_network_checkpoint(
        network=make_checkpoint.first,
        save_path=self.get_temp_dir(),
        map_func=lambda n: "foo")
    # Unbuilt loader: the restore is deferred, so the error surfaces on the
    # first call rather than here.
    loader = Parent()
    network.restore_network_checkpoint(
        loader, checkpoint, map_func=lambda n: "foo")
    with self.assertRaisesRegexp(
        ValueError,
        ("The map_func passed to restore_network_checkpoint for the Network"
         " 'parent_1' resulted in two variables named 'foo'")):
      loader(one)
    # Built loader: the restore runs immediately and raises right away.
    loader = Parent()
    loader(one)
    with self.assertRaisesRegexp(
        ValueError,
        ("The map_func passed to restore_network_checkpoint for the Network"
         " 'parent_2' resulted in two variables named 'foo'")):
      network.restore_network_checkpoint(
          loader, checkpoint, map_func=lambda n: "foo")
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:43,代码来源:network_test.py


示例6: testNetworkSaveAndRestoreIntoUnbuilt

 def testNetworkSaveAndRestoreIntoUnbuilt(self):
   """Restoring into a not-yet-called Network reproduces the saved weights."""
   directory = self.get_temp_dir()
   source = MyNetwork()
   inputs = constant_op.constant([[2.0]])
   source(inputs)
   self.evaluate(source.trainable_variables[0].assign([[17.0]]))
   path = network.save_network_checkpoint(source, directory)
   # `target` has not been called yet, so this restore happens before it is
   # built (deferred until its first call below).
   target = MyNetwork()
   network.restore_network_checkpoint(target, path)
   expected_output = self.evaluate(source(inputs))
   restored_output = self.evaluate(target(inputs))
   self.assertAllEqual(expected_output, restored_output)
   # Distinct variable objects that hold equal values.
   self.assertIsNot(source.variables[0], target.variables[0])
   self.assertAllEqual(self.evaluate(source.variables[0]),
                       self.evaluate(target.variables[0]))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:15,代码来源:network_test.py


示例7: _save_modify_load_network_built

 def _save_modify_load_network_built(self, net, global_step=None):
   """Saves `net`, perturbs its variables, and checks that restores undo it.

   Both the checkpoint directory and the explicit checkpoint path returned by
   save_network_checkpoint are used as restore sources.

   Args:
     net: a Network that has already been called, so its variables exist.
     global_step: forwarded unchanged to save_network_checkpoint.
   """
   ckpt_dir = self.get_temp_dir()
   ckpt_path = network.save_network_checkpoint(
       network=net, save_path=ckpt_dir, global_step=global_step)
   probe = constant_op.constant([[42.0]])
   baseline = self.evaluate(net(probe))

   def shift_variables(delta):
     # Move every variable away from its checkpointed value.
     for variable in net.variables:
       self.evaluate(variable.assign(variable + delta))

   shift_variables(1.)
   self.assertGreater(self.evaluate(net(probe)), baseline)
   # Restoring from the checkpoint directory must undo the shift...
   network.restore_network_checkpoint(net, save_path=ckpt_dir)
   self.assertAllEqual(baseline, self.evaluate(net(probe)))
   # ...and so must restoring from the explicit path returned by the save.
   shift_variables(2.)
   network.restore_network_checkpoint(net, save_path=ckpt_path)
   self.assertAllEqual(baseline, self.evaluate(net(probe)))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:22,代码来源:network_test.py


示例8: testVariableScopeStripping

 def testVariableScopeStripping(self):
   """Enclosing variable scopes do not prevent restoring elsewhere."""
   with variable_scope.variable_scope("scope1"), \
        variable_scope.variable_scope("scope2"):
     model = MyNetwork()
   model(constant_op.constant([[2.0]]))
   self.evaluate(model.variables[0].assign([[42.]]))
   self.assertEqual(model.name, "scope1/scope2/my_network")
   self.assertStartsWith(
       expected_start="scope1/scope2/my_network/dense/",
       actual=model.trainable_weights[0].name)
   save_path = network.save_network_checkpoint(model, self.get_temp_dir())
   # The scope separators show up as underscores in the checkpoint path.
   self.assertIn("scope1_scope2_my_network", save_path)

   def assert_restored(net):
     self.assertAllEqual([[42.]],
                         self.evaluate(net.variables[0]))

   fresh = MyNetwork()
   # Delayed restoration: applied on the first call.
   network.restore_network_checkpoint(fresh, save_path)
   fresh(constant_op.constant([[1.0]]))
   assert_restored(fresh)
   self.evaluate(fresh.variables[0].assign([[-1.]]))
   # Immediate restoration: the Network is already built.
   network.restore_network_checkpoint(fresh, save_path)
   assert_restored(fresh)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:23,代码来源:network_test.py


示例9: testLoadIntoUnbuiltSharedLayer

  def testLoadIntoUnbuiltSharedLayer(self):
    """Checks restoring into a shared Layer that has not been built yet.

    The checkpoint's variable names are remapped so that the default restore
    mapping targets a Layer owned by one Network (Owner) but shared with,
    and first built through, another (User). The test also checks that the
    restore fails with a "garbage collected" ValueError once the owning
    Network has been deleted.
    """

    class Owner(network.Network):
      # Tracks a single Dense layer named "first_layer".

      def __init__(self, name=None):
        super(Owner, self).__init__(name=name)
        self.first = self.track_layer(core.Dense(
            1, name="first_layer", use_bias=False))

      def call(self, x):
        return self.first(x)

    first_owner = Owner()

    class User(network.Network):
      # Tracks a Layer passed in from elsewhere plus its own Dense layer.

      def __init__(self, use_layer, name=None):
        super(User, self).__init__(name=name)
        self.first = self.track_layer(use_layer)
        self.second = self.track_layer(core.Dense(
            1, name="second_layer", use_bias=False))

      def call(self, x):
        return self.second(self.first(x))

    class LikeUserButNotSharing(network.Network):
      # Same layer structure as User, but creates both of its layers itself.

      def __init__(self, name=None):
        super(LikeUserButNotSharing, self).__init__(name=name)
        self.first = self.track_layer(core.Dense(
            1, name="first_layer", use_bias=False))
        self.second = self.track_layer(core.Dense(
            1, name="second_layer", use_bias=False))

      def call(self, x):
        return self.second(self.first(x))

    checkpoint_creator = LikeUserButNotSharing(name="checkpoint_creator")
    one = constant_op.constant([[1.0]])
    checkpoint_creator(one)
    self.assertEqual(2, len(checkpoint_creator.variables))
    self.evaluate(checkpoint_creator.variables[0].assign([[5.]]))
    self.evaluate(checkpoint_creator.variables[1].assign([[6.]]))
    # Re-map the variable names so that with default restore mapping we'll
    # attempt to restore into the unbuilt Layer.
    name_mapping = {
        "checkpoint_creator/first_layer/kernel": "owner/first_layer/kernel",
        "checkpoint_creator/second_layer/kernel": "second_layer/kernel",
    }
    save_path = network.save_network_checkpoint(
        checkpoint_creator,
        self.get_temp_dir(),
        map_func=lambda full_name: name_mapping[full_name])
    # `load_into` shares `first_owner`'s layer, which has not been built yet.
    load_into = User(use_layer=first_owner.first)
    network.restore_network_checkpoint(load_into, save_path)
    # The restore is deferred: the shared layer still has no variables.
    self.assertEqual(0, len(first_owner.variables))
    self.assertAllEqual(self.evaluate(checkpoint_creator(one)),
                        self.evaluate(load_into(one)))
    # Calling `load_into` built the shared layer, so its owner now sees it.
    self.assertEqual(1, len(first_owner.variables))
    self.assertAllEqual([[5.]], self.evaluate(load_into.variables[0]))
    self.assertAllEqual([[6.]], self.evaluate(load_into.variables[1]))
    first_owner(one)
    self.assertAllEqual([[5.]], self.evaluate(first_owner.variables[0]))

    # Try again with a garbage collected parent.
    first_owner = Owner()
    load_into = User(use_layer=first_owner.first)
    del first_owner
    gc.collect()
    def _restore_map_func(original_name):
      # Presumably remaps onto the auto-generated names of the second Owner
      # and User instances ('owner_1', 'user_1'); confirm.
      if original_name.startswith("owner/"):
        return original_name.replace("owner/", "owner_1/")
      else:
        return "user_1/" + original_name
    # The shared layer's owning Network is gone, so restoring must fail.
    with self.assertRaisesRegexp(ValueError, "garbage collected"):
      network.restore_network_checkpoint(
          load_into, save_path, map_func=_restore_map_func)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:77,代码来源:network_test.py



注:本文中的tensorflow.contrib.eager.python.network.save_network_checkpoint函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python head.multi_label_head函数代码示例发布时间:2022-05-27
下一篇:
Python checkpointable_utils.add_variable函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap