• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python checkpoint_utils.init_from_checkpoint函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.training.checkpoint_utils.init_from_checkpoint函数的典型用法代码示例。如果您正苦于以下问题：Python init_from_checkpoint函数的具体用法？Python init_from_checkpoint怎么用？Python init_from_checkpoint使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了init_from_checkpoint函数的16个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: testInitFromCheckpoint

  def testInitFromCheckpoint(self):
    """Restores checkpoint tensors into renamed graph variables.

    Writes a checkpoint with `_create_checkpoints`, then, in a fresh graph,
    creates variables under different names/scopes and calls
    `init_from_checkpoint` twice. Together the two calls exercise three
    mapping forms: tensor name -> variable name, checkpoint scope -> graph
    scope ("useful_scope/"), and tensor name -> variable object (`my3`).
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
          my1 = variable_scope.get_variable("my1", [1, 10])
          with variable_scope.variable_scope("some_other_scope"):
            my2 = variable_scope.get_variable("my2", [10, 10])
            with variable_scope.variable_scope("other_useful_scope"):
              my4 = variable_scope.get_variable("var4", [9, 9])
        my3 = variable_scope.get_variable("my3", [100, 100])

        # A "scope/" key maps every checkpoint tensor under that scope onto
        # the same relative name under the graph scope (presumably
        # "useful_scope/var4" -> my4 here, given my4's name).
        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
            "var1": "some_scope/my1",
            "useful_scope/": "some_scope/some_other_scope/other_useful_scope/",
        })
        # Targets may be variable names (strings) or variable objects.
        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
            "var2": "some_scope/some_other_scope/my2",
            "var3": my3,
        })

        session.run(variables.global_variables_initializer())
        self.assertAllEqual(my1.eval(session), v1)
        self.assertAllEqual(my2.eval(session), v2)
        self.assertAllEqual(my3.eval(session), v3)
        self.assertAllEqual(my4.eval(session), v4)

        # Check that tensors are not explicitly in the graph.
        # NOTE(review): the byte bound is a heuristic. my3 alone is 100x100
        # values, so embedding it as a constant would presumably exceed it.
        self.assertLess(len(str(session.graph.as_graph_def())), 29000)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:33,代码来源:checkpoint_utils_test.py


示例2: testInitialValueComesFromCheckpoint

  def testInitialValueComesFromCheckpoint(self):
    """`init_from_checkpoint` replaces the variable's initial value.

    A variable built from `my1.initialized_value()` before the call keeps the
    zeros initializer. One built after the call takes the checkpoint value.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope(
            "some_scope", initializer=init_ops.zeros_initializer()):
          my1 = variable_scope.get_variable("my1", [1, 10])

        # At this point, my1.initialized_value() will add ops that reference
        # the zeros initializer of my1.
        before = variables.Variable(my1.initialized_value(), name="before")

        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1})

        # At this point, my1.initialized_value() will add ops that reference
        # the newly set initializer of my1.
        after = variables.Variable(my1.initialized_value(), name="after")

        session.run(variables.global_variables_initializer())
        self.assertAllEqual(session.run(my1), v1)
        self.assertAllEqual(session.run(my1.initialized_value()), v1)
        self.assertAllClose(session.run(before), [[0.0] * 10])
        self.assertAllClose(session.run(after), v1)
        # Sanity check: the two snapshots really differ, so the asserts above
        # are not trivially satisfied by a checkpoint value of zeros.
        with self.assertRaises(AssertionError):
          self.assertAllClose(session.run(before), session.run(after))
开发者ID:QiangCai,项目名称:tensorflow,代码行数:29,代码来源:checkpoint_utils_test.py


示例3: _warm_start_var

def _warm_start_var(var, prev_ckpt, prev_tensor_name=None):
  """Initializes `var` from a tensor stored in the checkpoint `prev_ckpt`.

  Args:
    var: The variable in the current graph to warm-start. Accepted forms:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: every element must be a slice of one larger
        variable.
      (iv) `PartitionedVariable`
    prev_ckpt: Directory holding checkpoint file(s), or a path to a
      checkpoint. It must contain a tensor named `prev_tensor_name`, or, when
      that is not given, a tensor with the same name as `var`.
    prev_tensor_name: Name of the tensor to read from `prev_ckpt`. When falsy,
      the name inferred from `var` is used.

  Raises:
    TypeError: If `var` is none of the accepted forms above.
  """
  # pylint: disable=protected-access
  is_var = checkpoint_utils._is_variable
  if is_var(var):
    inferred_name = _infer_var_name([var])
    target = var
  elif isinstance(var, list) and all(is_var(item) for item in var):
    inferred_name = _infer_var_name(var)
    target = var
  elif isinstance(var, variables_lib.PartitionedVariable):
    # Infer the name from the whole PartitionedVariable, then restore into
    # its individual partitions.
    inferred_name = _infer_var_name([var])
    target = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  # pylint: enable=protected-access

  # Without an explicit previous name, assume the tensor name is unchanged.
  tensor_name = prev_tensor_name or inferred_name
  checkpoint_utils.init_from_checkpoint(prev_ckpt, {tensor_name: target})
开发者ID:AnishShah,项目名称:tensorflow,代码行数:33,代码来源:warm_starting_util.py


示例4: testInitialValueComesFromCheckpoint

  def testInitialValueComesFromCheckpoint(self):
    """Checks how `initialized_value()` tensors relate to the checkpoint.

    Before the global initializer runs, the tensor taken before
    `init_from_checkpoint` evaluates to zeros. The tensor taken after it
    evaluates to the checkpoint value. Once the global initializer has run,
    both evaluate to the checkpoint value.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1, _, _, _ = _create_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope(
            "some_scope", initializer=init_ops.zeros_initializer()):
          my1 = variable_scope.get_variable("my1", [1, 10])

        before = my1.initialized_value()

        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1})

        after = my1.initialized_value()

        # my1 is not initialized yet, so each tensor presumably falls back to
        # the initial value it captured when it was created.
        self.assertAllEqual(session.run(before), [[0.0] * 10])
        self.assertAllEqual(session.run(after), v1)

        session.run(variables.global_variables_initializer())

        self.assertAllEqual(session.run(my1), v1)
        self.assertAllEqual(session.run(my1.initialized_value()), v1)
        self.assertAllClose(session.run(before), v1)
        self.assertAllClose(session.run(after), v1)
        # Sanity check: the checkpoint value is not all zeros, so the
        # before/after distinction above is meaningful.
        with self.assertRaises(AssertionError):
          self.assertAllClose(v1, [[0.0] * 10])
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:29,代码来源:checkpoint_utils_test.py


示例5: testNoAdditionalReadOpsForResourceVariables

  def testNoAdditionalReadOpsForResourceVariables(self):
    """Restoring a ResourceVariable must not create extra read ops.

    Inside the "init_from_checkpoint" name scope, every op except the
    checkpoint_initializer ops must be an AssignVariableOp or an Identity.
    """
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      expected, _, _, _ = _create_checkpoints(sess, ckpt_dir)

    # Build the restore in a fresh graph.
    with ops.Graph().as_default() as graph:
      with self.session(graph=graph) as sess:
        restored = resource_variable_ops.ResourceVariable(
            [[0.0] * 10], name="my1")

        with ops.name_scope("init_from_checkpoint"):
          checkpoint_utils.init_from_checkpoint(ckpt_dir, {"var1": restored})

        # The restore itself must still produce the checkpoint value.
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(sess.run(restored), expected)

    scope_prefix = "init_from_checkpoint/"
    restore_prefix = "init_from_checkpoint/checkpoint_initializer"
    tolerated_types = ("AssignVariableOp", "Identity")

    def _is_unexpected(op):
      # Only ops in our scope matter; the restore ops themselves are allowed.
      if not op.name.startswith(scope_prefix):
        return False
      if op.name.startswith(restore_prefix):
        return False
      return op.type not in tolerated_types

    unexpected_ops = [op for op in graph.get_operations() if _is_unexpected(op)]
    self.assertEqual(unexpected_ops, [])
开发者ID:clsung,项目名称:tensorflow,代码行数:26,代码来源:checkpoint_utils_test.py


示例6: init_and_verify

 def init_and_verify(g):
   """Restores checkpoint tensor "var1" into "new_var1" and checks its value.

   Uses `checkpoint_dir`, `v1_value` and `self` from the enclosing scope,
   which is not visible in this excerpt.
   """
   restored = variable_scope.get_variable("new_var1", [1, 10])
   assignment_map = {"var1": "new_var1"}
   checkpoint_utils.init_from_checkpoint(checkpoint_dir, assignment_map)
   with self.test_session(graph=g) as sess:
     sess.run(variables.global_variables_initializer())
     self.assertAllEqual(v1_value, self.evaluate(restored))
开发者ID:ChristinaEricka,项目名称:tensorflow,代码行数:8,代码来源:checkpoint_utils_test.py


示例7: testRestoreRunsOnSameDevice

  def testRestoreRunsOnSameDevice(self):
    """Smoke test: restoring into a variable placed on "/job:ps" builds.

    Only checks that `init_from_checkpoint` accepts a "useful_scope/" mapping
    onto a variable created under an explicit device. No device placement is
    asserted here.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.cached_session() as session:
      _create_checkpoints(session, checkpoint_dir)

    with ops.Graph().as_default():
      with ops.device("/job:ps"):
        with variable_scope.variable_scope("useful_scope"):
          # Created only so the "useful_scope/" mapping below has a graph
          # variable to restore into. The returned handle is not needed.
          variable_scope.get_variable("var4", [9, 9])

      checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                            {"useful_scope/": "useful_scope/"})
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:12,代码来源:checkpoint_utils_test.py


示例8: init_and_verify

 def init_and_verify(g):
   """Restores checkpoint tensor "var1" into "new_var1" and checks its value.

   Uses `checkpoint_dir`, `v1_value` and `self` from the enclosing scope,
   which is not visible in this excerpt.
   """
   v1 = variable_scope.get_variable("new_var1", [1, 10])
   # Use string add to create new object in each replica
   # The target name is built at runtime, not taken from a shared literal.
   # NOTE(review): presumably this checks that the name mapping does not
   # depend on string object identity across replicas. Confirm.
   prefix = "new_"
   suffix = "var1"
   new_var1 = prefix + suffix
   checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
       "var1": new_var1,
   })
   with self.test_session(graph=g) as session:
     session.run(variables.global_variables_initializer())
     self.assertAllEqual(v1_value, self.evaluate(v1))
开发者ID:jackd,项目名称:tensorflow,代码行数:12,代码来源:checkpoint_utils_test.py


示例9: init_and_verify

 def init_and_verify(g):
   """Restores "var1"/"var2" into new variables and checks both values.

   "new_var2" is created with ON_READ synchronization and MEAN aggregation.
   Uses `checkpoint_dir`, `v1_value`, `v2_value` and `self` from the
   enclosing scope, which is not visible in this excerpt.
   """
   plain_var = variable_scope.get_variable("new_var1", [1, 10])
   on_read_var = variable_scope.get_variable(
       "new_var2",
       [10, 10],
       synchronization=variable_scope.VariableSynchronization.ON_READ,
       aggregation=variable_scope.VariableAggregation.MEAN,
   )
   assignment_map = {"var1": "new_var1", "var2": "new_var2"}
   checkpoint_utils.init_from_checkpoint(checkpoint_dir, assignment_map)
   with self.session(graph=g) as sess:
     sess.run(variables.global_variables_initializer())
     self.assertAllEqual(v1_value, self.evaluate(plain_var))
     self.assertAllEqual(v2_value, self.evaluate(on_read_var))
开发者ID:becster,项目名称:tensorflow,代码行数:14,代码来源:checkpoint_utils_test.py


示例10: testRestoreRunsOnSameDevice

  def testRestoreRunsOnSameDevice(self):
    """The restore initializer is placed on the variable's job, on CPU."""
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      _create_checkpoints(sess, ckpt_dir)

    with ops.Graph().as_default():
      with ops.device("/job:ps"):
        with variable_scope.variable_scope("useful_scope"):
          ps_var = variable_scope.get_variable("var4", [9, 9])

      scope_map = {"useful_scope/": "useful_scope/"}
      checkpoint_utils.init_from_checkpoint(ckpt_dir, scope_map)
      # The initializer runs on the same task as the variable, but always on
      # CPU. NOTE(review): inputs[1] is presumably the value fed into the
      # initializer op. Confirm against the variable implementation.
      # pylint: disable=protected-access
      restore_device = ps_var._initializer_op.op.inputs[1].device
      # pylint: enable=protected-access
      self.assertEqual(restore_device, "/job:ps/device:CPU:0")
开发者ID:QiangCai,项目名称:tensorflow,代码行数:15,代码来源:checkpoint_utils_test.py


示例11: testInitWithScopeDoesNotCaptureSuffixes

  def testInitWithScopeDoesNotCaptureSuffixes(self):
    """A "useful_scope/" mapping must leave "useful_scope_1/" untouched."""
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      _, _, _, expected_var4 = _create_checkpoints(sess, ckpt_dir)

    with ops.Graph().as_default() as graph:
      with variable_scope.variable_scope("useful_scope"):
        restored = variable_scope.get_variable("var4", [9, 9])
      # Shares the "useful_scope" prefix but is a different scope, so it must
      # keep its own initializer.
      untouched_init = [[1.0, 2.0], [3.0, 4.0]]
      with variable_scope.variable_scope("useful_scope_1"):
        untouched = variable_scope.get_variable(
            "var5", initializer=untouched_init)

      checkpoint_utils.init_from_checkpoint(ckpt_dir,
                                            {"useful_scope/": "useful_scope/"})
      with self.test_session(graph=graph) as sess:
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(restored.eval(sess), expected_var4)
        self.assertAllEqual(untouched.eval(sess), untouched_init)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:18,代码来源:checkpoint_utils_test.py


示例12: testInitToRootCheckpoint

  def testInitToRootCheckpoint(self):
    """Maps the checkpoint root "/" onto the graph root, restoring all vars."""
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      want1, want2, want3, want4 = _create_checkpoints(sess, ckpt_dir)

    # Fresh graph whose variable names mirror the checkpoint exactly.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        got1 = variable_scope.get_variable("var1", [1, 10])
        got2 = variable_scope.get_variable("var2", [10, 10])
        got3 = variable_scope.get_variable("var3", [100, 100])
        with variable_scope.variable_scope("useful_scope"):
          got4 = variable_scope.get_variable("var4", [9, 9])

        checkpoint_utils.init_from_checkpoint(ckpt_dir, {"/": "/"})

        sess.run(variables.global_variables_initializer())
        pairs = ((got1, want1), (got2, want2), (got3, want3), (got4, want4))
        for restored, expected in pairs:
          self.assertAllEqual(restored.eval(sess), expected)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:22,代码来源:checkpoint_utils_test.py


示例13: testInitFromCheckpointMissing

  def testInitFromCheckpointMissing(self):
    """`init_from_checkpoint` rejects unsatisfiable or malformed mappings."""
    ckpt_dir = self.get_temp_dir()
    with self.test_session() as sess:
      _, _, _, _ = _create_checkpoints(sess, ckpt_dir)

    # New graph and session.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        with variable_scope.variable_scope("some_scope"):
          _ = variable_scope.get_variable("my1", [10, 10])
          _ = variable_scope.get_variable(
              "my2", [1, 10],
              dtype=dtypes.int64,
              initializer=init_ops.zeros_initializer())

        # (expected error, checkpoint location, assignment map), run in order.
        failing_cases = (
            # No directory.
            (errors_impl.OpError, "no_dir", {"var1": "some_scope/my1"}),
            # No variable in checkpoint.
            (ValueError, ckpt_dir, {"no_var": "some_scope/my1"}),
            # No variable in the graph.
            (ValueError, ckpt_dir, {"var3": "some_scope/no_var"}),
            # Shape mismatch: graph my1 is [10, 10].
            (ValueError, ckpt_dir, {"var1": "some_scope/my1"}),
            # Variables 'my1' and 'my2' are missing in the checkpoint scope.
            (ValueError, ckpt_dir, {"useful_scope/": "some_scope/"}),
            # Mapping is not to a scope name.
            (ValueError, ckpt_dir, {"useful_scope": "some_scope/"}),
        )
        for expected_error, location, assignment_map in failing_cases:
          with self.assertRaises(expected_error):
            checkpoint_utils.init_from_checkpoint(location, assignment_map)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:44,代码来源:checkpoint_utils_test.py


示例14: testInitFromPartitionVar

  def testInitFromPartitionVar(self):
    """Restores a partitioned checkpoint tensor into partitioned variables.

    Covers three cases:
      * a target with the same partitioning as the checkpoint, mapped by name;
      * a target with a different number of partitions, mapped via a scope;
      * a target given directly as its list of partition variables.
    """
    checkpoint_dir = self.get_temp_dir()
    with self.test_session() as session:
      v1 = _create_partition_checkpoints(session, checkpoint_dir)

    # New graph and session.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
          my1 = variable_scope.get_variable(
              name="my1",
              shape=[100, 100],
              initializer=init_ops.zeros_initializer(),
              partitioner=partitioned_variables.min_max_variable_partitioner(
                  max_partitions=5, axis=0, min_slice_size=8 << 10))
          my1_var_list = my1._get_variable_list()  # pylint: disable=protected-access
        # Create another variable with different partitions than the variable in
        # the checkpoint.
        with variable_scope.variable_scope("some_other_scope"):
          my2 = variable_scope.get_variable(
              name="var1",
              shape=[100, 100],
              initializer=init_ops.zeros_initializer(),
              partitioner=partitioned_variables.min_max_variable_partitioner(
                  max_partitions=5, axis=0, min_slice_size=16 << 10))
          my2_var_list = my2._get_variable_list()  # pylint: disable=protected-access

        # "scope/var1" maps by full name onto my1. The "scope/" entry maps the
        # same tensor onto "some_other_scope/var1", i.e. my2.
        checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
            "scope/var1": "some_scope/my1",
            "scope/": "some_other_scope/"})

        session.run(variables.global_variables_initializer())
        my1_values = session.run(my1_var_list)
        self.assertAllEqual(my1_values, v1)
        my2_values = session.run(my2_var_list)
        # Verify we created a different number of partitions.
        self.assertNotEqual(len(my2_values), len(v1))
        # Verify the values were correctly initialized in spite of the
        # different partitions.
        full_my2_values = np.concatenate(my2_values, axis=0)
        full_v1_values = np.concatenate(v1, axis=0)
        self.assertAllEqual(full_my2_values, full_v1_values)

    # New graph and session: pass the partition list itself as the target.
    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as session:
        with variable_scope.variable_scope("some_scope"):
          my1 = variable_scope.get_variable(
              name="my1",
              shape=[100, 100],
              initializer=init_ops.truncated_normal_initializer(0.5),
              partitioner=partitioned_variables.min_max_variable_partitioner(
                  max_partitions=5, axis=0, min_slice_size=8 << 10))
          my1_var_list = my1._get_variable_list()  # pylint: disable=protected-access

        checkpoint_utils.init_from_checkpoint(checkpoint_dir,
                                              {"scope/var1": my1_var_list,})

        session.run(variables.global_variables_initializer())
        my1_values = session.run(my1_var_list)
        self.assertAllEqual(my1_values, v1)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:61,代码来源:checkpoint_utils_test.py


示例15: warm_start


#.........这里部分代码省略.........
      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection.  Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions.  If not explicitly provided, the variable
      is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model.  Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).
  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used.  This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}
  if var_name_to_prev_var_name is None:
    var_name_to_prev_var_name = {}
  logging.info("Warm-starting from: %s", (ckpt_to_initialize_from,))
  grouped_variables = _get_grouped_variables(vars_to_warm_start)

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used.  Err on the safer side by throwing an
  # exception if any are unused by the end of the loop.  It is easy to misname
  # a variable during this configuration, in which case without this check, we
  # would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into one call to init_from_checkpoint.
  vocabless_vars = {}
  for var_name, variable in six.iteritems(grouped_variables):
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      logging.info(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name,
              vocab_info.new_vocab,
              vocab_info.new_vocab_size,
              vocab_info.old_vocab,
              (vocab_info.old_vocab_size if vocab_info.old_vocab_size > 0
               else "All"),
              vocab_info.num_oov_buckets,
              prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    else:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if vars_to_warm_start:
        logging.info("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in order
        # for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        vocabless_vars[prev_tensor_name] = var

  checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from, vocabless_vars)
  prev_var_name_not_used = set(
      var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_prev_var_name that were not used: "
        "{0}.  Perhaps you misspelled them?  Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     grouped_variables.keys()))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_vocab_info that were not used: {0}. "
        " Perhaps you misspelled them?  Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:101,代码来源:warm_starting_util.py


示例16: _warmstart_var

def _warmstart_var(var, prev_ckpt, prev_tensor_name=None):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) `PartitionedVariable`
      (iv) list of `Variable` and/or `PartitionedVariable`: The list may
        contain one or more variables that have been sharded.  For example:
        [Variable('a/part_0'), Variable('b/part_0'), Variable('a/part_1'),
         PartitionedVariable([Variable('c/part_0'), Variable('c/part_1')])]
        where we have three whole Variables represented ('a', 'b', and 'c').
    prev_ckpt: A string specifying the directory with checkpoint file(s) or path
      to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.

  Raises:
    ValueError: If prev_tensor_name is not None, but the given var represents
      more than one Variable.
    TypeError: If var is not one of the allowed types.
  """
  if _is_variable(var):
    current_var_name = _infer_var_name([var])
  elif isinstance(var, variables.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  elif (isinstance(var, list) and all(
      _is_variable(v) or isinstance(v, variables.PartitionedVariable)
      for v in var)):
    if len(var) == 1:
      # Convert length-1 lists of vars to single tf.Variables.  This ensures
      # that checkpoint_utils.init_from_checkpoint() doesn't incorrectly
      # assume slice info is present.
      current_var_name = _infer_var_name(var)
      var = var[0]
      # init_from_checkpoint does not accept a PartitionedVariable, so a lone
      # one is expanded to its partition list, as in the branch above.
      if isinstance(var, variables.PartitionedVariable):
        var = var._get_variable_list()  # pylint: disable=protected-access
    else:
      # If we have multiple elements in var, we cannot assume they all
      # represent the same Variable.
      name_to_var_dict = saver.BaseSaverBuilder.OpListToDict(
          var, convert_variable_to_tensor=False)
      if prev_tensor_name:
        # Providing a prev_tensor_name is only viable if var represents a
        # single Variable.
        if len(name_to_var_dict) > 1:
          raise ValueError("var represented more than one Variable, but "
                           "prev_tensor_name was provided.")
        checkpoint_utils.init_from_checkpoint(prev_ckpt, {
            prev_tensor_name: var
        })
      else:
        # OpListToDict gives us roughly what we need, but the values in the
        # dict may be PartitionedVariables (which init_from_checkpoint does
        # not expect) that we need to convert to lists.
        name_to_var_dict_fixed = {}
        for name, var_or_partitioned in six.iteritems(name_to_var_dict):
          if isinstance(var_or_partitioned, variables.PartitionedVariable):
            name_to_var_dict_fixed[name] = (
                var_or_partitioned._get_variable_list())  # pylint: disable=protected-access
          else:
            name_to_var_dict_fixed[name] = var_or_partitioned
        checkpoint_utils.init_from_checkpoint(prev_ckpt, name_to_var_dict_fixed)
      return
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, PartitionedVariable, or "
        "list of Variable's and/or PartitionedVariable's, but is {}".format(
            type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name
  checkpoint_utils.init_from_checkpoint(prev_ckpt, {prev_tensor_name: var})
开发者ID:marcomarchesi,项目名称:tensorflow,代码行数:75,代码来源:warm_starting_util.py



注:本文中的tensorflow.python.training.checkpoint_utils.init_from_checkpoint函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python util.gather_initializers函数代码示例发布时间:2022-05-27
下一篇:
Python checkpoint_management.latest_checkpoint函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap