• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python graph_io.write_graph函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.framework.graph_io.write_graph函数的典型用法代码示例。如果您正苦于以下问题:Python write_graph函数的具体用法?Python write_graph怎么用?Python write_graph使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了write_graph函数的16个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: main

def main(unused_args):
  """Optimize the graph in FLAGS.input for inference and write FLAGS.output.

  The input is parsed as a binary GraphDef when FLAGS.frozen_graph is set and
  as a text-format GraphDef otherwise. The optimized graph is written as a
  serialized binary proto in the first case and through graph_io.write_graph
  in the second.

  Args:
    unused_args: Ignored.

  Returns:
    0 on success, -1 when the input file does not exist.
  """
  if not gfile.Exists(FLAGS.input):
    print("Input graph file '" + FLAGS.input + "' does not exist!")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  with gfile.Open(FLAGS.input, "rb") as f:
    data = f.read()
    if FLAGS.frozen_graph:
      input_graph_def.ParseFromString(data)
    else:
      text_format.Merge(data.decode("utf-8"), input_graph_def)

  output_graph_def = optimize_for_inference_lib.optimize_for_inference(
      input_graph_def,
      FLAGS.input_names.split(","),
      FLAGS.output_names.split(","),
      FLAGS.placeholder_type_enum,
      FLAGS.toco_compatible)

  if FLAGS.frozen_graph:
    # Write through a context manager so the handle is flushed and closed
    # before returning; the payload is serialized bytes, hence binary mode.
    with gfile.FastGFile(FLAGS.output, "wb") as f:
      f.write(output_graph_def.SerializeToString())
  else:
    graph_io.write_graph(output_graph_def,
                         os.path.dirname(FLAGS.output),
                         os.path.basename(FLAGS.output))
  return 0
开发者ID:AnishShah,项目名称:tensorflow,代码行数:28,代码来源:optimize_for_inference.py


示例2: strip_pruning_vars

def strip_pruning_vars(checkpoint_dir, output_node_names, output_dir, filename):
  """Write a GraphDef from which pruning variables and ops were stripped.

  The graph is rebuilt from the training checkpoints, passed through the
  pruning-variable stripping routine and stored as a binary GraphDef.

  Args:
    checkpoint_dir: Path to the checkpoints.
    output_node_names: Comma separated names of the output nodes; spaces are
      ignored.
    output_dir: Directory that receives the graph file.
    filename: Name of the GraphDef file to write.

  Returns:
    None

  Raises:
    ValueError: if output_node_names is empty.
  """
  if output_node_names:
    node_names = output_node_names.replace(' ', '').split(',')
  else:
    raise ValueError(
        'Need to specify atleast 1 output node through output_node_names flag')

  graph_from_checkpoint = strip_pruning_vars_lib.graph_def_from_checkpoint(
      checkpoint_dir, node_names)
  stripped_graph_def = strip_pruning_vars_lib.strip_pruning_vars_fn(
      graph_from_checkpoint, node_names)

  graph_io.write_graph(stripped_graph_def, output_dir, filename, as_text=False)
  written_path = os.path.join(output_dir, filename)
  logging.info('\nFinal graph written to %s', written_path)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:31,代码来源:strip_pruning_vars.py


示例3: testStripUnusedMultipleInputs

  def testStripUnusedMultipleInputs(self):
    """Checks strip_unused_from_files with two comma-separated input nodes.

    Writes a small graph to disk, strips it using two input nodes and one
    output node, then reloads the binary result and checks both its node list
    and the value it computes when the inputs are fed.
    """
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # We'll create an input graph that multiplies two input nodes.
    with ops.Graph().as_default():
      constant_node1 = constant_op.constant(1.0, name="constant_node1")
      constant_node2 = constant_op.constant(2.0, name="constant_node2")
      input_node1 = math_ops.subtract(constant_node1, 3.0, name="input_node1")
      input_node2 = math_ops.subtract(constant_node2, 5.0, name="input_node2")
      output_node = math_ops.multiply(
          input_node1, input_node2, name="output_node")
      math_ops.add(output_node, 2.0, name="later_node")
      # The context manager closes the session once the graph is on disk.
      with session.Session() as sess:
        output = sess.run(output_node)
        self.assertNear(6.0, output, 0.00001)
        graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

    # We save out the graph to disk, and then call the stripping routine.
    input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
    input_binary = False
    input_node_names = "input_node1,input_node2"
    input_node_types = [
        dtypes.float32.as_datatype_enum, dtypes.float32.as_datatype_enum
    ]
    output_binary = True
    output_node_names = "output_node"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

    strip_unused_lib.strip_unused_from_files(input_graph_path, input_binary,
                                             output_graph_path, output_binary,
                                             input_node_names,
                                             output_node_names,
                                             input_node_types)

    # Now we make sure the unused nodes are gone, and that the graph still
    # produces the expected result when the input nodes are fed.
    with ops.Graph().as_default():
      output_graph_def = graph_pb2.GraphDef()
      with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = importer.import_graph_def(output_graph_def, name="")

      self.assertEqual(3, len(output_graph_def.node))
      # input_node_names is a comma-separated list; comparing a node name with
      # the joined string never matched, so the shape check was never reached.
      input_names = input_node_names.split(",")
      for node in output_graph_def.node:
        self.assertNotEqual("Add", node.op)
        self.assertNotEqual("Sub", node.op)
        if node.name in input_names:
          self.assertIn("shape", node.attr)

      with session.Session() as sess:
        input_node1 = sess.graph.get_tensor_by_name("input_node1:0")
        input_node2 = sess.graph.get_tensor_by_name("input_node2:0")
        output_node = sess.graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node,
                          feed_dict={input_node1: [10.0],
                                     input_node2: [-5.0]})
        self.assertNear(-50.0, output, 0.00001)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:59,代码来源:strip_unused_test.py


示例4: main

def main(unused_argv):
  """Build the Inception-ResNet-v2 base graph and write its GraphDef.

  The graph takes a float32 placeholder named 'input_image' and is written to
  cmd_args.graph_dir under the name cmd_args.graph_filename.

  Args:
    unused_argv: Ignored.
  """
  graph = ops.Graph()
  with graph.as_default():
    # Model definition.
    image_shape = (1, None, None, 3)
    image_input = array_ops.placeholder(
        dtypes.float32, shape=image_shape, name='input_image')
    inception.inception_resnet_v2_base(image_input)

  graph_def = graph.as_graph_def()
  graph_io.write_graph(graph_def, cmd_args.graph_dir, cmd_args.graph_filename)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:10,代码来源:write_inception_resnet_v2_graph.py


示例5: _WriteGraph

 def _WriteGraph(self, run_params, gdef, graph_state):
   """Write `gdef` to a .pbtxt file named after the test and graph state.

   The file name is "<TestClass>_<test_name>_<label>.pbtxt" and the target
   directory is taken from the TRT_TEST_TMPDIR environment variable, falling
   back to the test's temp dir.

   Args:
     run_params: Object whose `test_name` attribute is used in the file name.
     gdef: The graph definition handed to graph_io.write_graph.
     graph_state: One of GraphState.ORIGINAL, CALIBRATE or INFERENCE.

   Raises:
     ValueError: if `graph_state` is none of the states listed above.
   """
   if graph_state == GraphState.ORIGINAL:
     label = "Original"
   elif graph_state == GraphState.CALIBRATE:
     label = "CalibEngine"
   elif graph_state == GraphState.INFERENCE:
     label = "InferEngine"
   else:
     # Without this branch `label` would be unbound below and surface as an
     # UnboundLocalError that hides the real cause.
     raise ValueError("Unknown graph state: %r" % (graph_state,))
   graph_name = (
       self.__class__.__name__ + "_" + run_params.test_name + "_" + label +
       ".pbtxt")
   temp_dir = os.getenv("TRT_TEST_TMPDIR", self.get_temp_dir())
   if temp_dir:
     logging.info("Writing graph to %s/%s", temp_dir, graph_name)
     graph_io.write_graph(gdef, temp_dir, graph_name)
开发者ID:omoindrot,项目名称:tensorflow,代码行数:14,代码来源:tf_trt_integration_test_base.py


示例6: freeze_graph

 def freeze_graph(sess, ckpt, output):
     """Restore `ckpt` into `sess`, cache the graph and freeze it.

     The checkpoint and a binary graph definition are first written under the
     "_Cache" directory; the freeze_graph tool then reads both back and writes
     "Frozen.pb".

     Args:
         sess: Session the checkpoint is restored into and saved from.
         ckpt: Checkpoint path handed to Saver.restore.
         output: Output node names handed to the freeze_graph tool.
     """
     # NOTE(review): `freeze_graph.freeze_graph` below has to resolve to the
     # imported freeze_graph module, not to this function - confirm the
     # enclosing scope.
     cache_dir = "_Cache"
     graph_file = "Model.pb"

     print("Loading checkpoint...")
     saver = tf.train.Saver()
     saver.restore(sess, ckpt)

     print("Writing graph...")
     if not os.path.isdir(cache_dir):
         os.makedirs(cache_dir)
     checkpoint_prefix = os.path.join(cache_dir, "Model")
     saver.save(sess, checkpoint_prefix)
     graph_io.write_graph(sess.graph, cache_dir, graph_file, False)

     print("Freezing graph...")
     freeze_graph.freeze_graph(
         os.path.join(cache_dir, graph_file),
         "",
         True,
         os.path.join(cache_dir, "Model"),
         output,
         "save/restore_all",
         "save/Const:0",
         "Frozen.pb",
         True,
         "",
     )
     print("Done")
开发者ID:bitores,项目名称:MachineLearning,代码行数:17,代码来源:Util.py


示例7: saveModel

    def saveModel(self, sess, outputDirectory = ""):
        """Write the session graph and a frozen copy of it.

        The graph goes to "<outputDirectory>tfModel.pb" and the frozen graph
        to "<outputDirectory>tfModel_frozen.pb"; weights are read from
        "<outputDirectory>models/model.ckpt". Paths are built by plain string
        concatenation, so a non-empty outputDirectory needs its own trailing
        separator.

        Args:
            sess: Session whose graph is exported.
            outputDirectory: Prefix prepended to every path used here.
        """
        from tensorflow.python.framework import graph_io
        from tensorflow.python.tools import freeze_graph

        graph_path = outputDirectory + "tfModel.pb"
        graph_io.write_graph(sess.graph, "./", graph_path)

        # Create a frozen version of the graph for distribution.
        frozen_graph_path = outputDirectory + "tfModel_frozen.pb"
        freeze_graph.freeze_graph(
            input_graph=graph_path,
            input_saver="",
            input_binary=False,
            input_checkpoint=outputDirectory + "models/model.ckpt",
            output_node_names="y_ph",
            restore_op_name="save/restore_all",
            filename_tensor_name="save/Const:0",
            output_graph=frozen_graph_path,
            clear_devices=False,
            initializer_nodes="")

        print("Frozen model (model and weights) saved in file: %s" % frozen_graph_path)
开发者ID:susy2015,项目名称:TopTagger,代码行数:23,代码来源:CreateModel.py


示例8: _testFreezeGraph

  def _testFreezeGraph(self, saver_write_version):
    """Freezes a one-variable graph and checks the frozen result.

    Args:
      saver_write_version: Checkpoint format version; passed to the Saver as
        `write_version` and to freeze_graph as `checkpoint_version`.
    """
    temp_dir = self.get_temp_dir()
    input_graph_name = "input_graph.pb"
    input_graph_path = os.path.join(temp_dir, input_graph_name)
    output_graph_path = os.path.join(temp_dir, "output_graph.pb")
    checkpoint_prefix = os.path.join(temp_dir, "saved_checkpoint")

    # Build a graph holding a single variable (1.0) that is multiplied by 2,
    # evaluate it once, then save a checkpoint and the graph definition.
    with ops.Graph().as_default():
      variable_node = variables.VariableV1(1.0, name="variable_node")
      doubled = math_ops.multiply(variable_node, 2.0, name="output_node")
      sess = session.Session()
      sess.run(variables.global_variables_initializer())
      self.assertNear(2.0, sess.run(doubled), 0.00001)
      saver = saver_lib.Saver(write_version=saver_write_version)
      checkpoint_path = saver.save(
          sess,
          checkpoint_prefix,
          global_step=0,
          latest_filename="checkpoint_state")
      graph_io.write_graph(sess.graph, temp_dir, input_graph_name)

    # Run the const conversion routine on the files written above.
    freeze_graph.freeze_graph(
        input_graph_path,
        "",  # input_saver_def_path
        False,  # input_binary
        checkpoint_path,
        "output_node",  # output_node_names
        "save/restore_all",  # restore_op_name
        "save/Const:0",  # filename_tensor_name
        output_graph_path,
        False,  # clear_devices
        "",
        "",
        "",
        checkpoint_version=saver_write_version)

    # The variable must have been replaced by a constant, and the graph must
    # still evaluate to the same value.
    with ops.Graph().as_default():
      frozen_graph_def = graph_pb2.GraphDef()
      with open(output_graph_path, "rb") as f:
        frozen_graph_def.ParseFromString(f.read())
        _ = importer.import_graph_def(frozen_graph_def, name="")

      self.assertEqual(4, len(frozen_graph_def.node))
      for node in frozen_graph_def.node:
        for variable_op in ("VariableV2", "Variable"):
          self.assertNotEqual(variable_op, node.op)

      with session.Session() as sess:
        frozen_output = sess.graph.get_tensor_by_name("output_node:0")
        self.assertNear(2.0, sess.run(frozen_output), 0.00001)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:68,代码来源:freeze_graph_test.py


示例9: testSinglePartitionedVariable

  def testSinglePartitionedVariable(self):
    """Ensures partitioned variables fail cleanly with freeze graph.

    Builds a graph whose only variable is created under a one-partition
    partitioner, checkpoints it, and checks that freezing the graph returns a
    result rather than the -1 error code.
    """
    checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
    checkpoint_state_name = "checkpoint_state"
    input_graph_name = "input_graph.pb"
    output_graph_name = "output_graph.pb"

    # Create a graph with partition variables. When weights are partitioned into
    # a single partition, the weights variable is followed by a identity ->
    # identity (an additional identity node).
    partitioner = partitioned_variables.fixed_size_partitioner(1)
    with ops.Graph().as_default():
      with variable_scope.variable_scope("part", partitioner=partitioner):
        batch_size, height, width, depth = 5, 128, 128, 3
        input1 = array_ops.zeros(
            (batch_size, height, width, depth), name="input1")
        input2 = array_ops.zeros(
            (batch_size, height, width, depth), name="input2")

        num_nodes = depth
        filter1 = variable_scope.get_variable("filter", [num_nodes, num_nodes])
        filter2 = array_ops.reshape(filter1, [1, 1, num_nodes, num_nodes])
        conv = nn.conv2d(
            input=input1, filter=filter2, strides=[1, 1, 1, 1], padding="SAME")
        node = math_ops.add(conv, input2, name="test/add")
        node = nn.relu6(node, name="test/relu6")

      # Save graph and checkpoints.
      sess = session.Session()
      sess.run(variables.global_variables_initializer())

      saver = saver_lib.Saver()
      checkpoint_path = saver.save(
          sess,
          checkpoint_prefix,
          global_step=0,
          latest_filename=checkpoint_state_name)
      graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)

      # Ensure this graph has partition variables.
      self.assertTrue([
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ])

    # Test freezing graph doesn't make it crash.
    output_node_names = "save/restore_all"
    output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)

    return_value = freeze_graph.freeze_graph_with_def_protos(
        input_graph_def=sess.graph_def,
        input_saver_def=None,
        input_checkpoint=checkpoint_path,
        output_node_names=output_node_names,
        restore_op_name="save/restore_all",  # default value
        filename_tensor_name="save/Const:0",  # default value
        output_graph=output_graph_path,
        clear_devices=False,
        initializer_nodes="")
    # The second argument of assertTrue is the failure message, so
    # `assertTrue(return_value, -1)` never compared against -1, and -1 itself
    # is truthy. Check the error code explicitly as well.
    self.assertNotEqual(-1, return_value)
    self.assertTrue(return_value)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:62,代码来源:freeze_graph_test.py


示例10: export_scoped_meta_graph


#.........这里部分代码省略.........
        graph (both Save/Restore ops and SaverDefs) that are not associated
        with the provided SaverDef.
    strip_default_attrs: Set to true if default valued attributes must be
        removed while exporting the GraphDef.
    **kwargs: Optional keyed arguments, including meta_info_def and
        collection_list.

  Returns:
    A `MetaGraphDef` proto and dictionary of `Variables` in the exported
    name scope.

  Raises:
    ValueError: When the `GraphDef` is larger than 2GB.
  """
  if context.executing_eagerly():
    raise ValueError("Exporting/importing meta graphs is not supported when "
                     "Eager Execution is enabled.")
  graph = graph or ops.get_default_graph()

  exclude_nodes = None
  unbound_inputs = []
  if export_scope or clear_extraneous_savers or clear_devices:
    if graph_def:
      new_graph_def = graph_pb2.GraphDef()
      new_graph_def.versions.CopyFrom(graph_def.versions)
      new_graph_def.library.CopyFrom(graph_def.library)

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph_def, saver_def)

      for node_def in graph_def.node:
        if _should_include_node(node_def.name, export_scope, exclude_nodes):
          new_node_def = _node_def(node_def, export_scope, unbound_inputs,
                                   clear_devices=clear_devices)
          new_graph_def.node.extend([new_node_def])
      graph_def = new_graph_def
    else:
      # Only do this complicated work if we want to remove a name scope.
      graph_def = graph_pb2.GraphDef()
      # pylint: disable=protected-access
      graph_def.versions.CopyFrom(graph.graph_def_versions)
      bytesize = 0

      if clear_extraneous_savers:
        exclude_nodes = _find_extraneous_saver_nodes(graph.as_graph_def(),
                                                     saver_def)

      for key in sorted(graph._nodes_by_id):
        if _should_include_node(graph._nodes_by_id[key].name,
                                export_scope,
                                exclude_nodes):
          value = graph._nodes_by_id[key]
          # pylint: enable=protected-access
          node_def = _node_def(value.node_def, export_scope, unbound_inputs,
                               clear_devices=clear_devices)
          graph_def.node.extend([node_def])
          if value.outputs:
            assert "_output_shapes" not in graph_def.node[-1].attr
            graph_def.node[-1].attr["_output_shapes"].list.shape.extend([
                output.get_shape().as_proto() for output in value.outputs])
          bytesize += value.node_def.ByteSize()
          if bytesize >= (1 << 31) or bytesize < 0:
            raise ValueError("GraphDef cannot be larger than 2GB.")

      graph._copy_functions_to_graph_def(graph_def, bytesize)  # pylint: disable=protected-access

    # It's possible that not all the inputs are in the export_scope.
    # If we would like such information included in the exported meta_graph,
    # add them to a special unbound_inputs collection.
    if unbound_inputs_col_name:
      # Clears the unbound_inputs collections.
      graph.clear_collection(unbound_inputs_col_name)
      for k in unbound_inputs:
        graph.add_to_collection(unbound_inputs_col_name, k)

  var_list = {}
  variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
                                   scope=export_scope)
  for v in variables:
    if _should_include_node(v, export_scope, exclude_nodes):
      var_list[ops.strip_name_scope(v.name, export_scope)] = v

  scoped_meta_graph_def = create_meta_graph_def(
      graph_def=graph_def,
      graph=graph,
      export_scope=export_scope,
      exclude_nodes=exclude_nodes,
      clear_extraneous_savers=clear_extraneous_savers,
      saver_def=saver_def,
      strip_default_attrs=strip_default_attrs,
      **kwargs)

  if filename:
    graph_io.write_graph(
        scoped_meta_graph_def,
        os.path.dirname(filename),
        os.path.basename(filename),
        as_text=as_text)

  return scoped_meta_graph_def, var_list
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:101,代码来源:meta_graph.py


示例11: load_model

# Load the Keras model in inference mode (learning phase 0).
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)

# Report the tensor names a consumer of the exported graph will need.
print('input is :', net_model.input.name)
print('output is:', net_model.output.name)

sess = K.get_session()

# Freeze the session graph, keeping the model's output op as the only output.
output_op_names = [net_model.output.op.name]
frozen_graph = freeze_session(K.get_session(), output_names=output_op_names)

from tensorflow.python.framework import graph_io

# Store the frozen graph as a binary protobuf.
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))


# --------

from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import utils
from tensorflow.python.saved_model import tag_constants, signature_constants
from tensorflow.python.saved_model.signature_def_utils_impl import  build_signature_def, predict_signature_def
from tensorflow.contrib.session_bundle import exporter
# Set up a SavedModel builder and a predict signature that maps the model's
# input tensor to 'images' and its output tensor to 'scores'.
export_path = 'folder_to_export'
builder = saved_model_builder.SavedModelBuilder(export_path)
signature = predict_signature_def(
    inputs={'images': net_model.input},
    outputs={'scores': net_model.output},
)
开发者ID:DXZ,项目名称:git_test,代码行数:31,代码来源:keras_to_tfsevring_final.py


示例12: print


# [optional] write graph definition in ascii

# In[ ]:

sess = K.get_session()

# Optionally dump the (unfrozen) graph definition in text format.
if args.graph_def:
    f = args.output_graphdef_file
    tf.train.write_graph(sess.graph.as_graph_def(), output_fld, f, as_text=True)
    print('saved the graph definition in ascii format at: ', str(Path(output_fld) / f))


# convert variables to constants and save

# In[ ]:

from tensorflow.python.framework import graph_util
from tensorflow.python.framework import graph_io
# Pick the graph definition to freeze: the quantized one when requested,
# otherwise the session graph as it is.
if args.quantize:
    from tensorflow.tools.graph_transforms import TransformGraph
    transforms = ["quantize_weights", "quantize_nodes"]
    transformed_graph_def = TransformGraph(sess.graph.as_graph_def(), [], pred_node_names, transforms)
    graph_def_to_freeze = transformed_graph_def
else:
    graph_def_to_freeze = sess.graph.as_graph_def()
constant_graph = graph_util.convert_variables_to_constants(sess, graph_def_to_freeze, pred_node_names)
graph_io.write_graph(constant_graph, output_fld, args.output_model_file, as_text=False)
print('saved the freezed graph (ready for inference) at: ', str(Path(output_fld) / args.output_model_file))

开发者ID:gustavelarrouturou,项目名称:keras_to_tensorflow,代码行数:27,代码来源:keras_to_tensorflow.py


示例13: testStripUnused

  def testStripUnused(self):
    """Checks strip_unused_from_files with a single input node.

    Covers an unknown input name, an input given as a tensor name, and the
    regular case in which the stripped graph is reloaded and evaluated.
    """
    temp_dir = self.get_temp_dir()
    input_graph_name = "input_graph.pb"
    input_graph_path = os.path.join(temp_dir, input_graph_name)
    output_graph_path = os.path.join(temp_dir, "output_graph.pb")

    # Build a graph: constant 1.0 -> subtract 3 -> multiply by 2 -> add 2,
    # check the multiply output and write the graph definition to disk.
    with ops.Graph().as_default():
      constant_node = constant_op.constant(1.0, name="constant_node")
      wanted_input_node = math_ops.subtract(
          constant_node, 3.0, name="wanted_input_node")
      output_node = math_ops.multiply(
          wanted_input_node, 2.0, name="output_node")
      math_ops.add(output_node, 2.0, name="later_node")
      sess = session.Session()
      self.assertNear(-4.0, sess.run(output_node), 0.00001)
      graph_io.write_graph(sess.graph, temp_dir, input_graph_name)

    def strip(input_node_names):
      # Text graph in, binary graph out, float32 placeholder type.
      strip_unused_lib.strip_unused_from_files(
          input_graph_path, False,
          output_graph_path, True,
          input_node_names,
          "output_node",
          dtypes.float32.as_datatype_enum)

    with self.assertRaises(KeyError):
      strip("does_not_exist")

    with self.assertRaises(ValueError):
      strip("wanted_input_node:0")

    input_node_names = "wanted_input_node"
    strip(input_node_names)

    # Reload the stripped graph, inspect its nodes and run it with a fed
    # input value.
    with ops.Graph().as_default():
      stripped_graph_def = graph_pb2.GraphDef()
      with open(output_graph_path, "rb") as f:
        stripped_graph_def.ParseFromString(f.read())
        _ = importer.import_graph_def(stripped_graph_def, name="")

      self.assertEqual(3, len(stripped_graph_def.node))
      for node in stripped_graph_def.node:
        for removed_op in ("Add", "Sub"):
          self.assertNotEqual(removed_op, node.op)
        if node.name == input_node_names:
          self.assertTrue("shape" in node.attr)

      with session.Session() as sess:
        graph = sess.graph
        input_node = graph.get_tensor_by_name("wanted_input_node:0")
        output_node = graph.get_tensor_by_name("output_node:0")
        output = sess.run(output_node, feed_dict={input_node: [10.0]})
        self.assertNear(20.0, output, 0.00001)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:63,代码来源:strip_unused_test.py


示例14: main

def main(argv):
  """Build the network graph from a config and optionally write it to a file.

  Args:
    argv: Full command line; argv[0] is skipped. See the argument parser below
      for the supported options.

  The graph is written when --output_file is given; the extension selects the
  content (.meta/.metatxt: exported meta graph, .pb/.pbtxt: graph def) and
  the encoding (text when the extension ends in "txt"). Two optional list
  files receive the names of the model params and of the state vars.
  """
  argparser = argparse.ArgumentParser(description='Compile some op')
  argparser.add_argument('config', help="filename to config-file")
  argparser.add_argument('--train', type=int, default=0, help='0 disable (default), 1 enable, -1 dynamic')
  argparser.add_argument('--eval', type=int, default=0, help='calculate losses. 0 disable (default), 1 enable')
  argparser.add_argument('--search', type=int, default=0, help='beam search. 0 disable (default), 1 enable')
  argparser.add_argument("--verbosity", default=4, type=int, help="5 for all seqs (default: 4)")
  argparser.add_argument("--summaries_tensor_name")
  argparser.add_argument("--output_file", help='output pb, pbtxt or meta, metatxt file')
  argparser.add_argument("--output_file_model_params_list", help="line-based, names of model params")
  argparser.add_argument("--output_file_state_vars_list", help="line-based, name of state vars")
  args = argparser.parse_args(argv[1:])
  # -1 is the documented "dynamic" mode handled below; it used to be rejected
  # here, which made that branch unreachable. 2 stays accepted because earlier
  # versions allowed it (it behaves like 1 through bool()).
  assert args.train in [-1, 0, 1, 2] and args.eval in [0, 1] and args.search in [0, 1]
  init(config_filename=args.config, log_verbosity=args.verbosity)
  with tf.Graph().as_default() as graph:
    assert isinstance(graph, tf.Graph)
    print("Create graph...")
    # See :func:`Engine._init_network`.
    tf.set_random_seed(42)
    if args.train < 0:
      # Dynamic mode: the train flag is a placeholder fed at run time.
      from TFUtil import get_global_train_flag_placeholder
      train_flag = get_global_train_flag_placeholder()
    else:
      train_flag = bool(args.train)
    eval_flag = bool(args.eval)
    search_flag = bool(args.search)
    network = create_graph(train_flag=train_flag, eval_flag=eval_flag, search_flag=search_flag)

    # Add a batch-major copy of every layer output that has a time axis, under
    # the fixed name "output_batch_major" inside the layer's scope.
    from TFNetworkLayer import LayerBase
    for layer in network.layers.values():
      assert isinstance(layer, LayerBase)
      if layer.output.time_dim_axis is None:
        continue
      with layer.cls_layer_scope(layer.name):
        tf.identity(layer.output.get_placeholder_as_batch_major(), name="output_batch_major")

    tf.group(*network.get_post_control_dependencies(), name="post_control_dependencies")

    if args.summaries_tensor_name:
      summaries_tensor = tf.summary.merge_all()
      assert isinstance(summaries_tensor, tf.Tensor), "no summaries in the graph?"
      tf.identity(summaries_tensor, name=args.summaries_tensor_name)

    if args.output_file and os.path.splitext(args.output_file)[1] in [".meta", ".metatxt"]:
      # https://www.tensorflow.org/api_guides/python/meta_graph
      saver = tf.train.Saver(
        var_list=network.get_saveable_params_list(), max_to_keep=2 ** 31 - 1)
      graph_def = saver.export_meta_graph()
    else:
      graph_def = graph.as_graph_def(add_shapes=True)

    print("Graph collection keys:", graph.get_all_collection_keys())
    print("Graph num operations:", len(graph.get_operations()))
    print("Graph def size:", Util.human_bytes_size(graph_def.ByteSize()))

    if args.output_file:
      filename = args.output_file
      _, ext = os.path.splitext(filename)
      assert ext in [".pb", ".pbtxt", ".meta", ".metatxt"], 'filename %r extension invalid' % filename
      print("Write graph to file:", filename)
      graph_io.write_graph(
        graph_def,
        logdir=os.path.dirname(filename),
        name=os.path.basename(filename),
        as_text=ext.endswith("txt"))
    else:
      print("Use --output_file if you want to store the graph.")

    if args.output_file_model_params_list:
      print("Write model param list to:", args.output_file_model_params_list)
      with open(args.output_file_model_params_list, "w") as f:
        for param in network.get_params_list():
          # Names are written without their ":0" output suffix.
          assert param.name[-2:] == ":0"
          f.write("%s\n" % param.name[:-2])

    if args.output_file_state_vars_list:
      print("Write state var list to:", args.output_file_state_vars_list)
      from TFUtil import CollectionKeys
      with open(args.output_file_state_vars_list, "w") as f:
        for param in tf.get_collection(CollectionKeys.STATE_VARS):
          assert param.name[-2:] == ":0"
          f.write("%s\n" % param.name[:-2])
开发者ID:rwth-i6,项目名称:returnn,代码行数:82,代码来源:compile_tf_graph.py


示例15: single_worker_inference

def single_worker_inference(infer_model,
                            ckpt,
                            inference_input_file,
                            inference_output_file,
                            hparams):
    """Inference with a single worker."""
    output_infer = inference_output_file

    # Read data
    infer_data = load_data(inference_input_file, hparams)
    print ("Batch size type:", type(hparams.infer_batch_size))
    with tf.Session(
            graph=infer_model.graph, config=utils.get_config_proto()) as sess:
        # revo debug
        # sess = tf_debug.TensorBoardDebugWrapperSession(sess, 'xy:6064')
        # initi table
        # sess.run(infer_model.insert_op[0])
        # sess.run(infer_model.insert_op[1])
        # sess.run(infer_model.insert_op[2])
        #
        loaded_infer_model = model_helper.load_model(
            infer_model.model, ckpt, sess, "infer", infer_model.insert_op)
        sess.run(
            infer_model.iterator.initializer,
            feed_dict={
                infer_model.src_placeholder: infer_data,
                infer_model.batch_size_placeholder: hparams.infer_batch_size
            })
        # Debug By Revo
        # value = sess.run(infer_model.iterator.source)
        # print ("Value:", value)
        # print ("Value,len:", len(value))
        # print ("Value,Type:", type(value))
        # print ("Value,Shape:", value.shape)
        # tmp_i = sess.run(infer_model.iterator)
        # print ("iterator:", tmp_i)
        # print ("iterator shape:", tmp_i.shape())
        # sys.exit()

        # print ("TEST")
        # # Initialize keys and values.
        # keys = tf.constant([1, 2, 3], dtype=tf.int64)
        # vals = tf.constant([1, 2, 3], dtype=tf.int64)
        # # Initialize hash table.
        # table = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int64, value_dtype=tf.int64, default_value=-1,
        #                                                 empty_key=0)
        # # Insert values to hash table and run the op.
        # insert_op = table.insert(keys, vals)
        # sess.run(insert_op)
        # # Print hash table lookups.
        # print(sess.run(table.lookup(keys)))
        # print("HERE2")

        # Saving Decoder model

        # Decode
        utils.print_out("# Start decoding ff3")
        # print ("indices:", hparams.inference_indices)

        if hparams.inference_indices:
            _decode_inference_indices(
                loaded_infer_model,
                sess,
                output_infer=output_infer,
                output_infer_summary_prefix=output_infer,
                inference_indices=hparams.inference_indices,
                tgt_eos=hparams.eos,
                subword_option=hparams.subword_option)
        else:
            nmt_utils.decode_and_evaluate(
                "infer",
                loaded_infer_model,
                sess,
                output_infer,
                ref_file=None,
                metrics=hparams.metrics,
                subword_option=hparams.subword_option,
                beam_width=hparams.beam_width,
                tgt_eos=hparams.eos,
                num_translations_per_input=hparams.num_translations_per_input)
        # saving model
        OUTPUT_FOLDER = '7.19'
        utils.print_out("Ouput Folder : " + OUTPUT_FOLDER)
        utils.print_out("# Saving Decoder model (Normal,ckpt) By Revo")
        loaded_infer_model.saver.save(sess, OUTPUT_FOLDER+"/current.ckpt")
        # save pb file
        graph_io.write_graph(sess.graph_def, OUTPUT_FOLDER, "current.graphdef")
        tf.train.export_meta_graph(filename=OUTPUT_FOLDER + '/current.meta')
        writer = tf.summary.FileWriter(OUTPUT_FOLDER, sess.graph)
        writer.close()
        # Frozen graph saving
        OUTPUT_FROZEN_FILE = 'nmt.pb'
        # OUTPUT_NAMES = ['index_to_string_Lookup', 'table_init', 'batch_iter_init']
        # maybe it is not utf8 as output
        OUTPUT_NODES = ['reverse_table_Lookup']
        utils.print_out("# Saving Decoder model (Frozen) By Revo")
        # extract method try
        # new_graph_def = tf.graph_util.extract_sub_graph(sess.graph_def, ["hash_table_2_Lookup"])
        #
        # remove train node
#.........这里部分代码省略.........
开发者ID:lkfo415579,项目名称:nmt_custom,代码行数:101,代码来源:inference.py


示例16: save

    def save(self, path=None, name=None, overwrite=True):
        """Save the model's metadata, checkpoint, graph and frozen graph.

        Everything is written below os.path.join(path, name): a pickled
        ".nn" metadata file, a checkpoint, "Model.pb", "IO.txt" with the
        input/output tensor names, and "Frozen.pb".

        Args:
            path: Parent directory; "Models" when None.
            name: Sub-directory name; "Cache" when None.
            overwrite: When a file named "Model" already exists in the target
                folder, remove it if true; otherwise pick the first free
                "Model(<n>)" name as the checkpoint prefix.
        """
        if path is None:
            path = "Models"
        if name is None:
            name = "Cache"
        folder = os.path.join(path, name)
        if not os.path.exists(folder):
            os.makedirs(folder)

        ckpt_prefix = os.path.join(folder, "Model")
        if os.path.isfile(ckpt_prefix):
            if overwrite:
                os.remove(ckpt_prefix)
            else:
                suffix = 1
                candidate = ckpt_prefix + "({})".format(suffix)
                while os.path.isfile(candidate):
                    suffix += 1
                    candidate = ckpt_prefix + "({})".format(suffix)
                ckpt_prefix = candidate

        print()
        print("=" * 60)
        print("Saving Model to {}...".format(folder))
        print("-" * 60)

        with open(ckpt_prefix + ".nn", "wb") as nn_file:
            # We don't need w_stds & b_inits when we load a model
            structures = {
                "_lr": self._lr,
                "_layer_names": self.layer_names,
                "_layer_params": self._layer_params,
                "_next_dimension": self._current_dimension
            }
            params = {
                "_logs": self._logs,
                "_metric_names": self._metric_names,
                "_optimizer": self._optimizer.name,
                "layer_special_params": self.layer_special_params
            }
            pickle.dump({"structures": structures, "params": params}, nn_file)

        tf.train.Saver().save(self._sess, ckpt_prefix)
        graph_io.write_graph(self._sess.graph, folder, "Model.pb", False)

        # Add the output ops under the "OutputFlow" scope, then take the most
        # recently added op whose name contains that scope as the output.
        with tf.name_scope("OutputFlow"):
            self.get_rs(self._tfx)
        output_op = next(
            (op.name for op in reversed(self._sess.graph.get_operations())
             if "OutputFlow" in op.name),
            "")

        io_lines = [
            "Input  : Entry/Placeholder:0",
            "Output : {}:0".format(output_op)
        ]
        with open(os.path.join(folder, "IO.txt"), "w") as io_file:
            io_file.write("\n".join(io_lines))

        # "Cache.pb" only serves as input to freeze_graph and is removed again.
        cache_graph_path = os.path.join(folder, "Cache.pb")
        graph_io.write_graph(self._sess.graph, folder, "Cache.pb", False)
        freeze_graph.freeze_graph(
            cache_graph_path,
            "", True, os.path.join(folder, "Model"),
            output_op, "save/restore_all", "save/Const:0",
            os.path.join(folder, "Frozen.pb"), True, ""
        )
        os.remove(cache_graph_path)

        print("Done")
        print("=" * 60)
开发者ID:bitores,项目名称:MachineLearning,代码行数:66,代码来源:Networks.py



注:本文中的tensorflow.python.framework.graph_io.write_graph函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python graph_to_function_def.graph_to_function_def函数代码示例发布时间:2022-05-27
下一篇:
Python function.define_function函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap