• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python loader.load函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了 Python 中 tensorflow.python.saved_model.loader.load 函数的典型用法代码示例。如果您正苦于以下问题:Python load 函数的具体用法?Python load 怎么用?Python load 使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了 load 函数的 20 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。

示例1: testSaveAsText

  def testSaveAsText(self):
    """Round-trips a text-format SavedModel that holds two meta graphs.

    Tag "foo" is exported together with its variables (v == 42); tag "bar" is
    exported as a meta graph only. After `builder.save(as_text=True)`, loading
    either tag is expected to give 42 for the first collected variable.
    """
    export_dir = os.path.join(
        compat.as_bytes(tf.test.get_temp_dir()), compat.as_bytes("astext"))
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # First meta graph: a single variable, exported along with its weights.
    with self.test_session(graph=tf.Graph()) as sess:
      saved_var = tf.Variable(42, name="v")
      sess.run(tf.initialize_all_variables())
      self.assertEqual(42, saved_var.eval())
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Second meta graph: same variable name, graph only (weights not updated).
    with self.test_session(graph=tf.Graph()) as sess:
      unsaved_var = tf.Variable(43, name="v")
      sess.run(tf.initialize_all_variables())
      self.assertEqual(43, unsaved_var.eval())
      builder.add_meta_graph(["bar"])

    # Write the SavedModel to disk in text format.
    builder.save(as_text=True)

    # "foo" had its variables saved and "bar" did not; both are expected to
    # restore the value written for "foo".
    for tag in ("foo", "bar"):
      with self.test_session(graph=tf.Graph()) as sess:
        loader.load(sess, [tag], export_dir)
        restored = tf.get_collection(tf.GraphKeys.VARIABLES)
        self.assertEqual(42, restored[0].eval())
开发者ID:apollos,项目名称:tensorflow,代码行数:33,代码来源:saved_model_test.py


示例2: testLegacyInitOp

  def testLegacyInitOp(self):
    """Checks that a `legacy_init_op` takes effect when a SavedModel is loaded.

    `v3` is created with the value 42 and the legacy init op assigns it
    `v1 + v2`; after loading tag "foo" the third entry of collection "v" is
    expected to evaluate to 3.
    """
    export_dir = self._get_export_dir("test_legacy_init_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      first = variables.Variable(1, name="v1")
      second = variables.Variable(2, name="v2")
      # `v3` starts at 42 and is kept out of the default collections.
      third = variables.Variable(
          42, name="v3", trainable=False, collections=[])
      # Registration order fixes the indices asserted after the reload.
      for var in (first, second, third):
        ops.add_to_collection("v", var)

      # The assignment below is what the legacy init op executes.
      v3_update = state_ops.assign(third, math_ops.add(first, second))
      legacy_init_op = control_flow_ops.group(v3_update, name="legacy_init_op")

      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], legacy_init_op=legacy_init_op)

    # Write the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      restored = ops.get_collection("v")
      # Index 2 holds `v3`, which should now be the sum of the first two
      # variables because the legacy init op ran after the restore.
      for index, expected in enumerate((1, 2, 3)):
        self.assertEqual(expected, restored[index].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:33,代码来源:saved_model_test.py


示例3: testSaveAsText

  def testSaveAsText(self):
    """Round-trips a text-format SavedModel that holds two meta graphs.

    Tag "foo" is exported with its variables (v == 42) and tag "bar" as a meta
    graph only. Loading either tag afterwards is expected to give 42 for the
    first global variable.
    """
    export_dir = self._get_export_dir("test_astext")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # First meta graph: a single variable, exported along with its weights.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Second meta graph: same variable name, graph only (weights not updated).
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 43)
      builder.add_meta_graph(["bar"])

    # Write the SavedModel to disk in text format.
    builder.save(as_text=True)

    # "foo" had its variables saved and "bar" did not; both are expected to
    # restore the value written for "foo".
    for tag in ("foo", "bar"):
      with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, [tag], export_dir)
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertEqual(42, global_vars[0].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:30,代码来源:saved_model_test.py


示例4: testCustomMainOp

  def testCustomMainOp(self):
    """Checks that a custom `main_op` takes effect when a SavedModel is loaded.

    The custom main op assigns `v1 + v2` to `v3` (created as 42); after loading
    tag "foo" the third entry of collection "v" is expected to evaluate to 3.
    """
    export_dir = self._get_export_dir("test_main_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      first = variables.Variable(1, name="v1")
      second = variables.Variable(2, name="v2")
      # `v3` starts out at 42 and is overwritten by the main op.
      third = variables.Variable(42, name="v3")
      # Registration order fixes the indices asserted after the reload.
      for var in (first, second, third):
        ops.add_to_collection("v", var)

      # Build the assignment under a control dependency on the stock main op.
      with ops.control_dependencies([main_op.main_op()]):
        total = math_ops.add(first._ref(), second._ref())
        custom_main_op = control_flow_ops.group(state_ops.assign(third, total))

      sess.run(custom_main_op)
      builder.add_meta_graph_and_variables(
          sess, ["foo"], main_op=custom_main_op)

    # Write the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      restored = ops.get_collection("v")
      # Index 2 holds `v3`, which should be the sum of the first two variables
      # because the main op ran after the restore.
      for index, expected in enumerate((1, 2, 3)):
        self.assertEqual(expected, restored[index].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:34,代码来源:saved_model_test.py


示例5: testTrainOpGroup

  def testTrainOpGroup(self):
    """Checks that a grouped train op is reloaded as an `ops.Operation`."""
    export_dir = self._get_export_dir("test_train_op_group")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      # Two variables, registered in collection "v" in creation order.
      first = variables.Variable(1, name="v1")
      ops.add_to_collection("v", first)
      second = variables.Variable(2, name="v2")
      ops.add_to_collection("v", second)

      sess.run(variables.global_variables_initializer())

      # An empty group serves as the train op.
      empty_group = control_flow_ops.group()
      sess.run(empty_group)

      # TODO(karmel): remove explicit call when in the public method.
      builder._add_train_op(empty_group)
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Write the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      restored = ops.get_collection("v")
      self.assertEqual(1, restored[0].eval())
      self.assertEqual(2, restored[1].eval())
      reloaded_train_op = ops.get_collection(constants.TRAIN_OP_KEY)[0]
      self.assertIsInstance(reloaded_train_op, ops.Operation)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:28,代码来源:saved_model_test.py


示例6: testTrainOpAfterVariables

  def testTrainOpAfterVariables(self):
    """Checks that a train op added after the variables export is tag-specific.

    Tag "pre_foo" is exported before any train op exists; tag "foo" is added
    afterwards. Only "foo" is expected to carry a TRAIN_OP_KEY entry.
    """
    export_dir = self._get_export_dir("test_train_op_after_variables")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
      # Two variables, registered in collection "v" in creation order.
      first = variables.Variable(1, name="v1")
      ops.add_to_collection("v", first)
      second = variables.Variable(2, name="v2")
      ops.add_to_collection("v", second)

      sess.run(variables.global_variables_initializer())
      # Export with variables before the train op is created.
      builder.add_meta_graph_and_variables(sess, ["pre_foo"])

      increment = state_ops.assign_add(first, second)
      sess.run(increment)
      # TODO(karmel): remove explicit call when in the public method.
      builder._add_train_op(increment)
      builder.add_meta_graph(["foo"])

    # Write the SavedModel to disk.
    builder.save()

    # The later meta graph exposes the train op as a Tensor.
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["foo"], export_dir)
      reloaded_train_op = ops.get_collection(constants.TRAIN_OP_KEY)[0]
      self.assertIsInstance(reloaded_train_op, ops.Tensor)

    # The earlier meta graph has no train op at all.
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, ["pre_foo"], export_dir)
      train_ops = ops.get_collection(constants.TRAIN_OP_KEY)
      self.assertFalse(train_ops)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:31,代码来源:saved_model_test.py


示例7: testCustomSaveable

  def testCustomSaveable(self):
    """Checks that a `CheckpointedOp` table survives a SavedModel round trip."""
    export_dir = self._get_export_dir("custom_saveable")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    def new_session():
      # Each call builds a session on a fresh graph with two CPU devices.
      return session.Session(
          graph=ops.Graph(),
          config=config_pb2.ConfigProto(device_count={"CPU": 2}))

    with new_session() as sess:
      # CheckpointedOp is a key-value table that can be saved across sessions.
      # The table registers itself in the SAVEABLE_OBJECTS collection.
      table = saver_test_utils.CheckpointedOp(name="v1")
      variables.global_variables_initializer().run()
      table.insert("k1", 3.0).run()
      # Keep the table reference in a collection so it can be found again
      # once the model has been reloaded.
      ops.add_to_collection("table_ref", table.table_ref)
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Write the SavedModel to disk.
    builder.save()

    with new_session() as sess:
      loader.load(sess, ["foo"], export_dir)
      # Wrap the checkpointed reference in a new CheckpointedOp.
      saved_ref = ops.get_collection("table_ref")[0]
      table = saver_test_utils.CheckpointedOp(name="v1", table_ref=saved_ref)
      self.assertEqual(b"k1", table.keys().eval())
      self.assertEqual(3.0, table.values().eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:28,代码来源:saved_model_test.py


示例8: testGraphWithoutVariables

  def testGraphWithoutVariables(self):
    """Round-trips two variable-free meta graphs and reads back their constants."""
    export_dir = self._get_export_dir("test_graph_has_variables")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # First graph: only the constant 5.0, exported under tag "foo".
    with self.test_session(graph=ops.Graph()) as sess:
      constant_5_name = constant_op.constant(5.0).name
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Second graph: only the constant 6.0, exported under tag "bar".
    with self.test_session(graph=ops.Graph()) as sess:
      constant_6_name = constant_op.constant(6.0).name
      builder.add_meta_graph(["bar"])

    # Write the SavedModel to disk.
    builder.save()

    # For each tag, fetch the stored constant by name and multiply it by the
    # other graph's value; both products must come out as 30.
    restore_cases = (
        ("foo", constant_5_name, 6.0),
        ("bar", constant_6_name, 5.0),
    )
    for tag, tensor_name, multiplier in restore_cases:
      with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, [tag], export_dir)
        stored = ops.get_default_graph().get_tensor_by_name(tensor_name)
        product = stored * constant_op.constant(multiplier)
        self.assertEqual(30.0, sess.run(product))
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:34,代码来源:saved_model_test.py


示例9: export_fn

 def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
   """Exports a SavedModel, then converts it and records feature importances.

   Delegates the export to `base_strategy.export`, reloads the result with the
   SERVING tag, parses the serialized tree ensemble, optionally passes it to
   `convert_fn`, and writes a "feature_importances" file under
   `<result_dir>/assets.extra`.

   Args:
     estimator: Forwarded unchanged to `base_strategy.export`.
     export_dir: Forwarded unchanged to `base_strategy.export`.
     checkpoint_path: Forwarded unchanged to `base_strategy.export`.
     eval_result: Forwarded to `base_strategy.export` and to `convert_fn`.

   Returns:
     The directory returned by `base_strategy.export`.
   """
   result_dir = base_strategy.export(
       estimator, export_dir, checkpoint_path, eval_result)
   with ops.Graph().as_default() as graph:
     with tf_session.Session(graph=graph) as sess:
       saved_model_loader.load(sess, [tag_constants.SERVING], result_dir)
       # Note: This is GTFlow internal API and might change.
       serialize_op = graph.get_operation_by_name(
           "ensemble_model/TreeEnsembleSerialize")
       # Only the second output (the serialized config) is used.
       _, serialized_config = sess.run(serialize_op.outputs)
       dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
       dtec.ParseFromString(serialized_config)

       num_dense = len(dense_floats)
       num_sparse_float = len(sparse_float_indices)
       num_sparse_int = len(sparse_int_indices)

       # The converted output goes into the same folder as the saved model.
       if convert_fn:
         convert_fn(dtec, sorted_feature_names, num_dense, num_sparse_float,
                    num_sparse_int, result_dir, eval_result)

       feature_importances = _get_feature_importances(
           dtec, sorted_feature_names, num_dense, num_sparse_float,
           num_sparse_int)
       # Highest importance first.
       ranked = sorted(
           feature_importances.items(), key=lambda item: -item[1])
       lines = ["%s, %f" % (name, score) for name, score in ranked]

       assets_dir = os.path.join(result_dir, "assets.extra")
       gfile.MakeDirs(assets_dir)
       importances_path = os.path.join(assets_dir, "feature_importances")
       with gfile.GFile(importances_path, "w") as f:
         f.write("\n".join(lines))
   return result_dir
开发者ID:jiayouwyhit,项目名称:tensorflow,代码行数:33,代码来源:custom_export_strategy.py


示例10: testVariables

  def testVariables(self):
    """Checks variable restoration across meta graphs sharing one checkpoint.

    "foo" saves v1 and v2. "bar" (v2 only) and "baz" (v3 only) are added without
    weights. Loading "foo" and "bar" is expected to restore the saved values,
    while loading "baz" is expected to raise `errors.NotFoundError`.
    """
    export_dir = os.path.join(
        compat.as_bytes(tf.test.get_temp_dir()), compat.as_bytes("variables"))
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # "foo": two variables, exported along with their weights.
    with self.test_session(graph=tf.Graph()) as sess:
      first = tf.Variable(1, name="v1")
      second = tf.Variable(2, name="v2")
      sess.run(tf.initialize_all_variables())
      self.assertEqual(1, first.eval())
      self.assertEqual(2, second.eval())
      builder.add_meta_graph_and_variables(sess, ["foo"])

    # Graph-only exports (weights are not updated):
    # - "bar" holds v2, a subset of the saved variables;
    # - "baz" holds v3, which is disjoint from the saved variables.
    graph_only_cases = (("bar", "v2", 3), ("baz", "v3", 4))
    for tag, var_name, initial_value in graph_only_cases:
      with self.test_session(graph=tf.Graph()) as sess:
        var = tf.Variable(initial_value, name=var_name)
        sess.run(tf.initialize_all_variables())
        self.assertEqual(initial_value, var.eval())
        builder.add_meta_graph([tag])

    # Write the SavedModel to disk.
    builder.save()

    # "foo" restores both saved variables; "bar" restores only the subset it
    # declares, and that subset takes the checkpointed value.
    restore_cases = (("foo", (1, 2)), ("bar", (2,)))
    for tag, expected_values in restore_cases:
      with self.test_session(graph=tf.Graph()) as sess:
        loader.load(sess, [tag], export_dir)
        collection_vars = tf.get_collection(tf.GraphKeys.VARIABLES)
        self.assertEqual(len(collection_vars), len(expected_values))
        for index, expected in enumerate(expected_values):
          self.assertEqual(expected, collection_vars[index].eval())

    # "baz" shares no variable with the saved set, so restoring it must fail.
    with self.test_session(graph=tf.Graph()) as sess:
      with self.assertRaises(errors.NotFoundError):
        loader.load(sess, ["baz"], export_dir)
开发者ID:apollos,项目名称:tensorflow,代码行数:59,代码来源:saved_model_test.py


示例11: testClearExtraneousSavers

  def testClearExtraneousSavers(self):
    """Checks that user-added Savers are dropped from the exported meta graph.

    Two Savers are added to the graph before export; after reloading, only one
    SaverDef and one Save/Restore op pair (named "save_2/...") are expected.
    The asserted op names depend on the order in which the Savers are created.
    """
    export_dir = os.path.join(test.get_temp_dir(),
                              "test_clear_extraneous_savers")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Create a variable and a Saver.
    with ops.Graph().as_default() as graph:
      with session.Session(
          target="",
          config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
        self._init_and_validate_variable(sess, "v", 42)

        # Add two Savers, which should be removed in
        # add_meta_graph_and_variables() in favor of the locally added one.
        saver1 = tf_saver.Saver()
        graph.add_to_collection(ops.GraphKeys.SAVERS, saver1)
        saver2 = tf_saver.Saver()
        graph.add_to_collection(ops.GraphKeys.SAVERS, saver2)

        # Confirm there are two SaverDefs.
        savers = graph.get_collection(ops.GraphKeys.SAVERS)
        self.assertEqual(2, len(savers))

        # Confirm there are two Save and two Restore ops.
        save_op_names = set([x.name for x in graph.get_operations()
                             if x.type == "SaveV2"])
        self.assertSetEqual(set(["save/SaveV2", "save_1/SaveV2"]),
                            save_op_names)

        restore_op_names = set([x.name for x in graph.get_operations()
                                if x.type == "RestoreV2"])
        self.assertSetEqual(set(["save/RestoreV2", "save_1/RestoreV2"]),
                            restore_op_names)

        # The SavedModel builder adds its own Saver, for a total of three.
        builder.add_meta_graph_and_variables(
            sess, [tag_constants.TRAINING], clear_devices=True)

    # Save the SavedModel to disk.
    builder.save()

    # Restore the graph.
    with ops.Graph().as_default() as graph:
      with self.test_session(graph=graph) as sess:
        loader.load(sess, [tag_constants.TRAINING], export_dir)
        self.assertEqual(
            42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())

        # Confirm that the reloaded graph has only one SaverDef.
        savers = ops.get_collection(ops.GraphKeys.SAVERS)
        self.assertEqual(1, len(savers))

        # The reloaded graph should have exactly one Save and one Restore op.
        # The "save_2" prefix is the third Saver created above (the builder's).
        save_op_names = set([x.name for x in graph.get_operations()
                             if x.type == "SaveV2"])
        self.assertSetEqual(set(["save_2/SaveV2"]), save_op_names)
        restore_op_names = set([x.name for x in graph.get_operations()
                                if x.type == "RestoreV2"])
        self.assertSetEqual(set(["save_2/RestoreV2"]), restore_op_names)
开发者ID:adityaatluri,项目名称:tensorflow,代码行数:59,代码来源:saved_model_test.py


示例12: testSignatureDefs

  def testSignatureDefs(self):
    """Checks that each meta graph keeps its own SignatureDef map.

    "foo" is exported with one signature ("foo_key"); "bar" is exported with
    two, one of which reuses "foo_key" with a different method name. Each tag
    is expected to load back exactly the map it was exported with.
    """
    export_dir = self._get_export_dir("test_signature_defs")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    def empty_signature(method_name):
      # A SignatureDef with no inputs or outputs, just a method name.
      return signature_def_utils.build_signature_def(
          dict(), dict(), method_name)

    # One variable and a single signature entry; exported with weights.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      builder.add_meta_graph_and_variables(
          sess, ["foo"],
          signature_def_map={"foo_key": empty_signature("foo")})

    # Same variable with two signature entries; no weights are saved. The
    # "foo_key" entry deliberately differs from the one in the previous graph.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 43)
      builder.add_meta_graph(
          ["bar"],
          signature_def_map={
              "bar_key": empty_signature("bar"),
              "foo_key": empty_signature("foo_new"),
          })

    # Write the SavedModel to disk.
    builder.save()

    # Each tag must expose exactly its own (key, method_name) pairs, and both
    # restore the variable value saved with "foo".
    expectations = (
        ("foo", (("foo_key", "foo"),)),
        ("bar", (("bar_key", "bar"), ("foo_key", "foo_new"))),
    )
    for tag, expected_methods in expectations:
      with self.test_session(graph=ops.Graph()) as sess:
        meta_graph = loader.load(sess, [tag], export_dir)
        self.assertEqual(
            42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())

        signatures = meta_graph.signature_def
        self.assertEqual(len(signatures), len(expected_methods))
        for key, method_name in expected_methods:
          self.assertEqual(method_name, signatures[key].method_name)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:59,代码来源:saved_model_test.py


示例13: testStripDefaultAttrsInconsistentConsumerDefaults

  def testStripDefaultAttrsInconsistentConsumerDefaults(self):
    """Checks loading of stripped attrs when the consumer's defaults differ.

    A graph with a "Complex" op is exported with `strip_default_attrs=True`.
    The registered "Complex" OpDef is then edited in place, first to have no
    defaults and then to default "Tout" to complex128; loading is expected to
    fail in both cases. The registered OpDef is put back to its original
    contents before the test returns, and every session is closed.
    """
    if ops._USE_C_API: return  # TODO(skyewm): get this working

    export_dir = self._get_export_dir(
        "test_strip_default_attrs_no_consumer_defaults")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled. This must remove the following
    # defaults for the "Complex" Op:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], strip_default_attrs=True)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Keep a pristine copy of the registered OpDef: the edits below change the
    # registry entry itself and must be undone, even if an assertion fails.
    complex_op_def = op_def_registry.get_registered_ops()["Complex"]
    original_complex_op_def = op_def_pb2.OpDef()
    original_complex_op_def.CopyFrom(complex_op_def)
    try:
      # Update the Op registry to remove defaults for all attrs("T", "Tout")
      # from the "Complex" OpDef.
      for attr_def in complex_op_def.attr:
        attr_def.ClearField("default_value")

      # Loading the SavedModel via the loader must fail because the SavedModel
      # does not have any attr values for the "Complex" node and the current
      # op registry does not have any default values for the "Complex" op.
      with session.Session(graph=ops.Graph()) as sess:
        with self.assertRaisesRegexp(
            ValueError,
            "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
          loader.load(sess, ["foo"], export_dir)

      # Update the Op registry to change the defaults for attr "Tout"
      # (complex64 -> complex128).
      complex_op_def.CopyFrom(original_complex_op_def)
      for attr_def in complex_op_def.attr:
        if attr_def.name == "Tout":
          attr_def.default_value.type = types_pb2.DT_COMPLEX128

      # Loading the SavedModel via the loader must set "Tout" attr_value for
      # the "Complex" node according to the latest defaults (complex128). This
      # is expected to fail the model import as there is no OpKernel registered
      # to handle attrs "T" (float32) and "Tout" (complex128).
      with session.Session(graph=ops.Graph()) as sess:
        with self.assertRaisesRegexp(
            errors.InvalidArgumentError,
            ".*No OpKernel was registered to support Op \'Complex\' with these "
            "attrs..*"):
          loader.load(sess, ["foo"], export_dir)
    finally:
      # Undo the registry edits so later tests see the stock "Complex" OpDef.
      complex_op_def.CopyFrom(original_complex_op_def)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:57,代码来源:saved_model_test.py


示例14: freeze_saved_model

def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                       output_arrays, tag_set, signature_key):
  """Converts a SavedModel to a frozen graph.

  Args:
    saved_model_dir: SavedModel directory to convert.
    input_arrays: List of input tensors to freeze graph with. Uses input arrays
      from SignatureDef when none are provided.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
    output_arrays: List of output tensors to freeze graph with. Uses output
      arrays from SignatureDef when none are provided.
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present.
    signature_key: Key identifying SignatureDef containing inputs and outputs.

  Returns:
    frozen_graph_def: Frozen GraphDef.
    in_tensors: List of input tensors for the graph.
    out_tensors: List of output tensors for the graph.

  Raises:
    ValueError:
      SavedModel doesn't contain a MetaGraphDef identified by tag_set.
      signature_key is not in the MetaGraphDef.
      assets/ directory is in the MetaGraphDef.
      input_shapes does not match the length of input_arrays.
      input_arrays or output_arrays are not valid.
  """
  # Read SignatureDef.
  meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
  signature_def = _get_signature_def(meta_graph, signature_key)
  inputs, outputs = _get_inputs_outputs(signature_def)

  # Check SavedModel for assets directory.
  collection_def = meta_graph.collection_def
  if constants.ASSETS_KEY in collection_def:
    raise ValueError("SavedModels with assets/ directory are not supported.")

  graph = ops.Graph()
  with session.Session(graph=graph) as sess:
    loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)

    # Gets input and output tensors.
    # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
    in_tensors = _get_tensors(graph, inputs, input_arrays)
    out_tensors = _get_tensors(graph, outputs, output_arrays)
    set_tensor_shapes(in_tensors, input_shapes)

    # Freeze with respect to the tensors that are actually returned, so that a
    # caller-supplied `output_arrays` is honored rather than always using the
    # SignatureDef outputs. Node names are tensor names without the ":<index>"
    # suffix.
    output_names = [tensor.name.split(":")[0] for tensor in out_tensors]
    frozen_graph_def = tf_graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), output_names)

    return frozen_graph_def, in_tensors, out_tensors
开发者ID:AnishShah,项目名称:tensorflow,代码行数:55,代码来源:convert_saved_model.py


示例15: _TestStaticOp

  def _TestStaticOp(self, use_function_backup):
    """Converts a SavedModel and runs both the GraphDef and SavedModel outputs.

    Does nothing when `is_tensorrt_enabled()` is false.

    Args:
      use_function_backup: Forwarded to `_ConvertGraph` and to every `_TestRun`
        call.
    """
    if not is_tensorrt_enabled():
      return

    tmp_dir = self.get_temp_dir()
    input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
    output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
    self._WriteInputSavedModel(input_saved_model_dir)
    output_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        maximum_cached_engines=2,  # This is noop, added just for testing.
        use_function_backup=use_function_backup)

    def run_both_batch_sizes(sess):
      # Batch size 1 uses the default engine embedded in the graphdef.
      # Batch size 2 exceeds the max_batch_size, so it should try to fall
      # back to the TF function and the engine is not expected to run.
      for batch_size, engine_expected in ((1, True), (2, False)):
        self._TestRun(
            sess,
            batch_size,
            use_function_backup=use_function_backup,
            expect_engine_is_run=engine_expected)

    # Test the output GraphDef.
    with ops.Graph().as_default():
      importer.import_graph_def(output_graph_def, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        run_both_batch_sizes(sess)

    # Test the output SavedModel.
    with ops.Graph().as_default():
      with self.session(config=self._GetConfigProto()) as sess:
        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
        run_both_batch_sizes(sess)
开发者ID:perfmjs,项目名称:tensorflow,代码行数:51,代码来源:trt_convert_test.py


示例16: __init__

  def __init__(self,
               export_dir,
               signature_def_key=None,
               signature_def=None,
               input_names=None,
               output_names=None,
               tags=None,
               graph=None):
    """Initialize a `CoreEstimatorPredictor`.

    Args:
      export_dir: a path to a directory containing a `SavedModel`.
      signature_def_key: Optional string specifying the signature to use. If
        `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of
        `signature_def_key` and `signature_def` should be specified.
      signature_def: A `SignatureDef` proto specifying the inputs and outputs
        for prediction. Only one of `signature_def_key` and `signature_def`
        should be specified.
      input_names: A dictionary mapping strings to `Tensor`s in the `SavedModel`
        that represent the input. The keys can be any string of the user's
        choosing.
      output_names: A dictionary mapping strings to `Tensor`s in the
        `SavedModel` that represent the output. The keys can be any string of
        the user's choosing.
      tags: Optional. Comma-separated string of tags that will be used to
        retrieve the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`.
      graph: Optional. The Tensorflow `graph` in which prediction should be
        done.
    Raises:
      ValueError: If more than one of signature_def_key OR signature_def OR
        (input_names AND output_names) is specified.
    """
    _check_signature_arguments(
        signature_def_key, signature_def, input_names, output_names)
    tags = tags or DEFAULT_TAGS
    self._graph = graph or ops.Graph()

    with self._graph.as_default():
      self._session = session.Session()
      try:
        loader.load(self._session, tags.split(','), export_dir)
      except Exception:
        # A failed load leaves the object unusable; release the session
        # instead of leaking it, then let the original error propagate.
        self._session.close()
        raise

    if input_names is None:
      # No explicit tensor names given: take them from the SignatureDef,
      # looking it up in the SavedModel when none was passed in.
      if signature_def is None:
        signature_def = _get_signature_def(signature_def_key, export_dir, tags)
      input_names = {k: v.name for k, v in signature_def.inputs.items()}
      output_names = {k: v.name for k, v in signature_def.outputs.items()}

    # Resolve tensor names to the tensors of the loaded graph.
    self._feed_tensors = {k: self._graph.get_tensor_by_name(v)
                          for k, v in input_names.items()}
    self._fetch_tensors = {k: self._graph.get_tensor_by_name(v)
                           for k, v in output_names.items()}
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:51,代码来源:saved_model_predictor.py


示例17: testVariables

  def testVariables(self):
    """Exercises variable restoration across three tagged meta graphs.

    One SavedModel receives the tags "foo" (v1 and v2, added together with
    their variables), "bar" (v2 only) and "baz" (v3 only); the latter two are
    added via `add_meta_graph`, i.e. without storing new weights. Loading
    "foo" and "bar" must expose the values recorded for "foo", while loading
    "baz" must raise `errors.NotFoundError`.
    """
    model_dir = self._get_export_dir("test_variables")
    model_builder = saved_model_builder.SavedModelBuilder(model_dir)

    # Tag "foo": two variables, added along with their weights.
    with self.test_session(graph=ops.Graph()) as foo_sess:
      for var_name, var_value in (("v1", 1), ("v2", 2)):
        self._init_and_validate_variable(foo_sess, var_name, var_value)
      model_builder.add_meta_graph_and_variables(foo_sess, ["foo"])

    # Tag "bar": only "v2", a subset of the variables whose weights were
    # stored for "foo". Only the meta graph is added here.
    with self.test_session(graph=ops.Graph()) as bar_sess:
      self._init_and_validate_variable(bar_sess, "v2", 3)
      model_builder.add_meta_graph(["bar"])

    # Tag "baz": only "v3", which shares no variable with "foo". Again only
    # the meta graph is added.
    with self.test_session(graph=ops.Graph()) as baz_sess:
      self._init_and_validate_variable(baz_sess, "v3", 4)
      model_builder.add_meta_graph(["baz"])

    # Write the SavedModel out.
    model_builder.save()

    # Loading "foo" must yield both variables with the values that were saved.
    with self.test_session(graph=ops.Graph()) as foo_sess:
      loader.load(foo_sess, ["foo"], model_dir)
      restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self.assertEqual(len(restored), 2)
      for expected, variable in zip((1, 2), restored):
        self.assertEqual(expected, variable.eval())

    # Loading "bar" must yield its single variable carrying the checkpointed
    # value (2), not the value 3 it was initialized with above.
    with self.test_session(graph=ops.Graph()) as bar_sess:
      loader.load(bar_sess, ["bar"], model_dir)
      restored = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self.assertEqual(len(restored), 1)
      sole_variable = restored[0]
      self.assertEqual(2, sole_variable.eval())

    # Loading "baz" must fail: none of its variables were ever saved.
    with self.test_session(graph=ops.Graph()) as baz_sess:
      with self.assertRaises(errors.NotFoundError):
        loader.load(baz_sess, ["baz"], model_dir)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:51,代码来源:saved_model_test.py


示例18: testCollections

  def testCollections(self):
    """Verifies that each tagged meta graph restores its own collection.

    The variable "v" is placed in collection "foo_vars" for tag "foo" (added
    with variables) and in "bar_vars" for tag "bar" (meta graph only). After
    loading either tag, only that tag's collection may be populated, and the
    variable must evaluate to 42, the value stored with tag "foo".
    """
    temp_root = compat.as_bytes(tf.test.get_temp_dir())
    model_dir = os.path.join(temp_root, compat.as_bytes("collections"))
    model_builder = saved_model_builder.SavedModelBuilder(model_dir)

    # Tag "foo": variable registered under "foo_vars", added with weights.
    with self.test_session(graph=tf.Graph()) as foo_sess:
      foo_var = tf.Variable(42, name="v")
      tf.add_to_collection("foo_vars", foo_var)
      init_op = tf.initialize_all_variables()
      foo_sess.run(init_op)
      self.assertEqual(42, foo_var.eval())
      model_builder.add_meta_graph_and_variables(foo_sess, ["foo"])

    # Tag "bar": same variable name registered under "bar_vars" instead.
    # Only the meta graph is added, so no new weights are stored.
    with self.test_session(graph=tf.Graph()) as bar_sess:
      bar_var = tf.Variable(43, name="v")
      tf.add_to_collection("bar_vars", bar_var)
      init_op = tf.initialize_all_variables()
      bar_sess.run(init_op)
      self.assertEqual(43, bar_var.eval())
      model_builder.add_meta_graph(["bar"])

    # Write the SavedModel out.
    model_builder.save()

    # Loading "foo": "foo_vars" holds exactly one element and "bar_vars" is
    # empty.
    with self.test_session(graph=tf.Graph()) as foo_sess:
      loader.load(foo_sess, ["foo"], model_dir)
      foo_members = tf.get_collection("foo_vars")
      self.assertEqual(len(foo_members), 1)
      self.assertEqual(42, foo_members[0].eval())

      bar_members = tf.get_collection("bar_vars")
      self.assertEqual(len(bar_members), 0)

    # Loading "bar": the collection exported with this meta graph is
    # "bar_vars", and its variable carries the value saved for tag "foo" (42)
    # rather than the 43 it was initialized with.
    with self.test_session(graph=tf.Graph()) as bar_sess:
      loader.load(bar_sess, ["bar"], model_dir)
      bar_members = tf.get_collection("bar_vars")
      self.assertEqual(len(bar_members), 1)
      self.assertEqual(42, bar_members[0].eval())

      foo_members = tf.get_collection("foo_vars")
      self.assertEqual(len(foo_members), 0)
开发者ID:apollos,项目名称:tensorflow,代码行数:50,代码来源:saved_model_test.py


示例19: testAssets

  def testAssets(self):
    """Saves a model with an asset collection and validates it after loading.

    A meta graph tagged "foo" is added together with an asset collection
    built for "hello42.txt"; after saving and reloading, the loaded object's
    `collection_def` is checked through `_validate_asset_collection`.
    """
    model_dir = self._get_export_dir("test_assets")
    model_builder = saved_model_builder.SavedModelBuilder(model_dir)

    with self.test_session(graph=ops.Graph()) as save_sess:
      self._init_and_validate_variable(save_sess, "v", 42)

      # Write an extra file into the temp dir. It is not passed to the asset
      # collection helper below.
      temp_root = compat.as_bytes(test.get_temp_dir())
      ignored_filepath = os.path.join(temp_root, compat.as_bytes("ignored.txt"))
      file_io.write_string_to_file(ignored_filepath, "will be ignored")

      # Build the asset collection that is handed to the builder.
      assets = self._build_asset_collection(
          "hello42.txt", "foo bar baz", "asset_file_tensor")

      model_builder.add_meta_graph_and_variables(
          save_sess, ["foo"], assets_collection=assets)

    # Write the SavedModel out.
    model_builder.save()

    # Reload tag "foo" and validate the asset collection it exposes.
    with self.test_session(graph=ops.Graph()) as load_sess:
      loaded = loader.load(load_sess, ["foo"], model_dir)
      self._validate_asset_collection(
          model_dir, loaded.collection_def,
          "hello42.txt", "foo bar baz", "asset_file_tensor:0")
   

鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python save.save函数代码示例 — 发布时间:2022-05-27
下一篇:
Python model_analyzer.profile函数代码示例 — 发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap