• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python compat.as_str函数代码示例

原作者：未注明 来自：未注明 收藏 邀请

本文整理汇总了Python中tensorflow.python.util.compat.as_str函数的典型用法代码示例。如果您正苦于以下问题：Python as_str函数的具体用法？Python as_str怎么用？Python as_str使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了as_str函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: meta_graph_transform

def meta_graph_transform(
    base_meta_graph_def, input_names, output_names, transforms, tags,
    checkpoint_path=None):
  """Apply the Graph Transform tool to a MetaGraphDef.

  The new graph_def is produced by `_do_transforms`. The meta_info_def is
  copied from the base with its tags replaced by `tags`. The saver,
  collections and signature_defs are then copied from the base, excluding any
  node that the transforms removed.

  Args:
    base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: Ordered list of graph transform names. These are exactly the
      names supported by the Graph Transform Tool, plus 'freeze_graph'.
    tags: Tags to annotate the transformed MetaGraphDef with (replacing the
      base meta_info_def's tags).
    checkpoint_path: Optional path of a checkpoint to restore during
      freezing (default None).

  Returns:
    A new transformed MetaGraphDef protocol buffer.
  """
  retained_initializers = _find_all_mandatory_retain_ops(base_meta_graph_def)
  new_graph_def = _do_transforms(
      base_meta_graph_def.graph_def, input_names, output_names,
      retained_initializers, transforms, base_meta_graph_def.saver_def,
      checkpoint_path)

  result = _meta_graph_pb2.MetaGraphDef()
  result.graph_def.CopyFrom(new_graph_def)
  info = result.meta_info_def
  info.CopyFrom(base_meta_graph_def.meta_info_def)
  info.ClearField('tags')
  for tag_value in tags:
    info.tags.append(tag_value)

  # Op names present before the transforms but gone afterwards.
  names_before = {compat.as_str(node.name)
                  for node in base_meta_graph_def.graph_def.node}
  names_after = {compat.as_str(node.name) for node in result.graph_def.node}
  pruned_names = names_before - names_after

  # Saver, then collections, then signatures: each copied without pruned ops.
  _add_pruned_saver(base_meta_graph_def, result, pruned_names)
  for collection_key in base_meta_graph_def.collection_def:
    _add_pruned_collection(
        base_meta_graph_def, result, collection_key, pruned_names)
  for signature_key in base_meta_graph_def.signature_def:
    _add_pruned_signature(
        base_meta_graph_def, result, signature_key, pruned_names)

  return result
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:60,代码来源:meta_graph_transform.py


示例2: _PopulateTFImportGraphDefOptions

def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements):
  """Populates the TF_ImportGraphDefOptions `options`.

  Sets the name prefix (with name and prefix uniquification enabled),
  registers every `input_map` entry, and registers every requested return
  element. Every name handed to the C API is normalized to `str` first.

  Args:
    options: The TF_ImportGraphDefOptions handle to populate.
    prefix: Prefix for the names of imported nodes.
    input_map: Dict from names in the imported graph (str or bytes) to
      objects providing `_as_tf_output()`. A key starting with '^' remaps a
      control dependency on that op to the op of the mapped value. Any other
      key is parsed as a tensor name and mapped to the value's output.
    return_elements: Optional iterable of op names ('name') or tensor names
      ('name:index') to return; `None` means none.
  """
  c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
  c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
  c_api.TF_ImportGraphDefOptionsSetUniquifyPrefix(options, True)

  for input_src, input_dst in input_map.items():
    input_src = compat.as_str(input_src)
    if input_src.startswith('^'):
      # Control-dependency remap. Pass the name as `str`, like every other
      # name this function hands to the C API.
      src_name = compat.as_str(input_src[1:])
      dst_op = input_dst._as_tf_output().oper  # pylint: disable=protected-access
      c_api.TF_ImportGraphDefOptionsRemapControlDependency(options, src_name,
                                                           dst_op)
    else:
      src_name, src_idx = _ParseTensorName(input_src)
      src_name = compat.as_str(src_name)
      dst_output = input_dst._as_tf_output()  # pylint: disable=protected-access
      c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name,
                                                    src_idx, dst_output)
  for name in return_elements or []:
    # Normalize first so the ':' membership test also works for bytes names.
    name = compat.as_str(name)
    if ':' in name:
      op_name, index = _ParseTensorName(name)
      op_name = compat.as_str(op_name)
      c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index)
    else:
      c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name)
开发者ID:andrewharp,项目名称:tensorflow,代码行数:28,代码来源:importer.py


示例3: _init_from_proto

  def _init_from_proto(self, hparam_def):
    """Adds one hyperparameter per entry of an `HParamDef` protocol buffer.

    The oneof field set on each entry (its 'kind') selects the conversion
    applied before `add_hparam` is called:
    * int64 data becomes Python `int`.
    * bytes data is decoded to `str` (UTF-8 is assumed).
    * Any other kind is stored as-is.
    Kinds ending in '_value' carry a single value. All other kinds carry a
    list in their `.value` field, which is converted element by element.

    Args:
      hparam_def: `HParamDef` protocol buffer.
    """
    assert isinstance(hparam_def, hparam_pb2.HParamDef)
    for hparam_name, hparam_value in hparam_def.hparam.items():
      kind = hparam_value.WhichOneof('kind')
      # `int` / `str` keep the stored types identical under Python 2 and 3.
      if kind.startswith('int64'):
        convert = int
      elif kind.startswith('bytes'):
        convert = compat.as_str
      else:
        convert = None
      payload = getattr(hparam_value, kind)
      if kind.endswith('_value'):
        # Single value.
        self.add_hparam(
            hparam_name, payload if convert is None else convert(payload))
      elif convert is None:
        # List of values, stored unchanged.
        self.add_hparam(hparam_name, list(payload.value))
      else:
        # List of values, converted one by one.
        self.add_hparam(hparam_name, [convert(item) for item in payload.value])
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:34,代码来源:hparam.py


示例4: _ProcessReturnElementsParam

def _ProcessReturnElementsParam(return_elements):
  """Type-checks and possibly canonicalizes `return_elements`."""
  if return_elements is None: return None
  if not all(isinstance(x, compat.bytes_or_text_types)
             for x in return_elements):
    raise TypeError('return_elements must be a list of strings.')
  return tuple(compat.as_str(x) for x in return_elements)
开发者ID:andrewharp,项目名称:tensorflow,代码行数:7,代码来源:importer.py


示例5: _clean_save_and_restore

def _clean_save_and_restore(graph_def, op, removed_op_names):
  """Clean the specified save and restore op.

  Filters three parallel lists and keeps position i only when tensor name i
  is not removed (as judged by `_is_removed`):
  * the '<op>/tensor_names' const tensor,
  * the '<op>/shape_and_slices' const tensor,
  * the op's own 'dtypes' attr.
  The tensor shapes and '_output_shapes' attrs of both const nodes are then
  resized to the new length.

  Args:
    graph_def: A GraphDef proto to be transformed.
    op: The save or restore op to update (modified in place).
    removed_op_names: Collection of op names that have been removed.
  """
  names_node = _find_op(graph_def, op.name + '/tensor_names')
  shapes_node = _find_op(graph_def, op.name + '/shape_and_slices')
  names_tensor = names_node.attr['value'].tensor
  shapes_tensor = shapes_node.attr['value'].tensor

  kept_names, kept_shapes, kept_dtypes = [], [], []
  for position, tensor_name in enumerate(names_tensor.string_val):
    if _is_removed(compat.as_str(tensor_name), removed_op_names):
      continue
    kept_names.append(tensor_name)
    kept_shapes.append(shapes_tensor.string_val[position])
    kept_dtypes.append(op.attr['dtypes'].list.type[position])

  names_tensor.string_val[:] = kept_names
  names_tensor.tensor_shape.dim[0].size = len(kept_names)
  shapes_tensor.string_val[:] = kept_shapes
  shapes_tensor.tensor_shape.dim[0].size = len(kept_shapes)
  op.attr['dtypes'].list.type[:] = kept_dtypes

  names_output_shape = names_node.attr['_output_shapes'].list.shape[0]
  names_output_shape.dim[0].size = len(kept_names)
  shapes_output_shape = shapes_node.attr['_output_shapes'].list.shape[0]
  shapes_output_shape.dim[0].size = len(kept_shapes)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:33,代码来源:meta_graph_transform.py


示例6: encode_arg

  def encode_arg(arg, path):
    """Returns the signature encoding of one argument.

    * Tensors become a `TensorSpec` whose name is the tensor op's
      "_user_specified_name" attr when that attr exists, `path` is non-empty
      and the attr differs from `path[0]`. Otherwise the name is the
      '/'-joined `path`.
    * Ints, floats, bools, `None`, `DType`s and `TensorSpec`s are returned
      unchanged.
    * Anything else becomes `UnknownArgument()`.
    """
    if not isinstance(arg, ops.Tensor):
      passthrough_types = (int, float, bool, type(None), dtypes.DType,
                           tensor_spec.TensorSpec)
      return arg if isinstance(arg, passthrough_types) else UnknownArgument()

    try:
      specified_name = compat.as_str(
          arg.op.get_attr("_user_specified_name"))
    except ValueError:
      specified_name = None

    use_specified = bool(path and specified_name and
                         specified_name != path[0])
    # A differing user-specified name means the user explicitly renamed the
    # argument relative to the Python parameter it was passed to.
    spec_name = (specified_name if use_specified
                 else "/".join(str(part) for part in path))
    return tensor_spec.TensorSpec(arg.shape, arg.dtype, spec_name)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:27,代码来源:func_graph.py


示例7: assert_equal_graph_def

def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  The comparison ignores versions and the ordering of nodes, attrs and
  control inputs. Nodes are matched up by name, so node naming must be
  consistent between the two graphs. The comparison itself is delegated to
  `pywrap_tensorflow.EqualGraphDefWrapper`, on the serialized protos.

  When `checkpoint_v2` is True, both arguments are first passed to
  `_strip_checkpoint_v2_randomized`. Its return value is unused, so it
  presumably edits the protos in place; callers should not rely on the
  arguments being unchanged.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
        values that appear in V2 checkpoints.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for actual, got %s" %
                    type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for expected, got %s" %
                    type(expected).__name__)

  if checkpoint_v2:
    for graph_def in (actual, expected):
      _strip_checkpoint_v2_randomized(graph_def)

  serialized_actual = actual.SerializeToString()
  serialized_expected = expected.SerializeToString()
  diff = pywrap_tensorflow.EqualGraphDefWrapper(serialized_actual,
                                                serialized_expected)
  if not diff:
    return
  raise AssertionError(compat.as_str(diff))
开发者ID:LUTAN,项目名称:tensorflow,代码行数:32,代码来源:test_util.py


示例8: _create_new_tf_function

def _create_new_tf_function(func_graph):
  """Converts func_graph to a TF_Function and adds it to the current graph.

  The graph is first converted to a C TF_Function. That function is then
  round-tripped through a Python FunctionDef into a `_function` definition,
  which carries `func_graph._functions` over as sub-functions. The definition
  is added to `func_graph._outer_graph`, presumably the graph that is current
  while the cond is being built (TODO confirm).

  Args:
    func_graph: function._FuncGraph whose `inputs` and `outputs` become the
      function's arguments and results.

  Returns:
    `func_graph.name`, the name of the new TF_Function (no hash is appended
    to it; see `append_hash_to_fn_name` below).
  """
  c_func = c_api.TF_GraphToFunction_wrapper(
      func_graph._c_graph,
      compat.as_str(func_graph.name),
      False,  # append_hash_to_fn_name: keep the name equal to func_graph.name
      None,  # opers
      [t._as_tf_output() for t in func_graph.inputs],
      [t._as_tf_output() for t in func_graph.outputs],
      [],  # presumably output_names (none given) -- TODO confirm vs. C API
      None,  # opts
      None)  # description
  # Wrap the C function so the wrapper owns it. Binding the wrapper to `_`
  # keeps it alive until this function returns, i.e. past the read of
  # `c_func` below. Presumably the wrapper deletes `c_func` once collected.
  _ = c_api_util.ScopedTFFunction(c_func)

  # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
  # deserializing it into a Python FunctionDef, then reserializing it to create
  # a new TF_Function that we add to the graph.
  fdef = _function.function_def_from_tf_function(c_func)
  defined_func = _function._from_definition(fdef)
  defined_func._sub_functions = func_graph._functions
  defined_func.add_to_graph(func_graph._outer_graph)

  return func_graph.name
开发者ID:godyd2702,项目名称:tensorflow,代码行数:30,代码来源:cond_v2_impl.py


示例9: _node_def

def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Works on a deep copy of `from_node_def`:
  * Inputs outside `export_scope` are not stripped. They get
    `_UNBOUND_INPUT_PREFIX` inserted after any leading '^' and are appended
    to `unbound_inputs`. All other inputs have `export_scope` stripped.
  * The node name is stripped.
  * "_class" colocation entries are stripped; when `export_scope` is set,
    only entries inside it are kept.
  * For Enter/RefEnter nodes, a "frame_name" inside `export_scope` is
    stripped and one outside it is kept unchanged.
  * Every other attr is copied unchanged.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list to which the names of unbound inputs are appended.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for i, v in enumerate(node_def.input):
    if (export_scope and
        not node_def.input[i].lstrip("^").startswith(export_scope)):
      # Adds "$unbound_inputs_" prefix to the unbound name so they are easily
      # identifiable.
      node_def.input[i] = re.sub(r"([\^]|^)(.*)",
                                 r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                 compat.as_str(v))
      unbound_inputs.append(node_def.input[i])
    else:
      node_def.input[i] = ops.strip_name_scope(v, export_scope)
  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for k, v in six.iteritems(from_node_def.attr):
    if k == "_class":
      new_s = [compat.as_bytes(
          ops.strip_name_scope(s, export_scope)) for s in v.list.s
               if not export_scope or
               compat.as_str(s).split("@")[1].startswith(export_scope)]
      node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=new_s)))
    elif node_def.op in ("Enter", "RefEnter") and k == "frame_name":
      if not export_scope or compat.as_str(v.s).startswith(export_scope):
        new_s = compat.as_bytes(ops.strip_name_scope(v.s, export_scope))
        node_def.attr[k].CopyFrom(attr_value_pb2.AttrValue(s=new_s))
      else:
        # The frame is outside the export scope, so keep it as-is. Without
        # this branch `new_s` would be unbound here, or would be a stale list
        # left over from a "_class" attr.
        node_def.attr[k].CopyFrom(v)
    else:
      node_def.attr[k].CopyFrom(v)

  if clear_devices:
    node_def.device = ""

  return node_def
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:46,代码来源:meta_graph.py


示例10: __init__

  def __init__(self, name, graph, operations, inputs, outputs, attrs):
    """Initializes an eager defined function.

    The constructor:
    * converts `operations` of `graph` into a C TF_Function,
    * sets every entry of `attrs` on it,
    * reads its FunctionDef back,
    * registers it when executing eagerly,
    * records the definition and signature on `self`.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      operations: list of Operation; the subset of operations in the graph
        which will be in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
      attrs: dict mapping names of attributes to their AttrValue values
    """
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        graph._c_graph,  # pylint: disable=protected-access
        compat.as_str(name),
        False,
        [o._c_op for o in operations],  # pylint: disable=protected-access
        [t._as_tf_output() for t in inputs],  # pylint: disable=protected-access
        [t._as_tf_output() for t in outputs],  # pylint: disable=protected-access
        [],
        None,
        compat.as_str(""))
    # Give `fn` to its owning wrapper immediately rather than at the end. If
    # setting an attr or building the FunctionDef below raises, the wrapper
    # (presumably deleting the C function when collected) still releases it.
    # `fn` itself remains usable.
    scoped_fn = c_api_util.ScopedTFFunction(fn)

    for attr_name, attr_value in attrs.items():
      # Named `attr_name` so the `name` argument is not shadowed.
      serialized = attr_value.SerializeToString()
      # TODO(iga): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use status.
      pywrap_tensorflow.TF_FunctionSetAttrValueProto(
          fn, compat.as_str(attr_name), serialized)

    # TODO(apassos) avoid creating a FunctionDef (specially to grab the
    # signature, but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))
    if context.executing_eagerly():
      _register(fn)
    self.definition = function_def
    self.name = function_def.signature.name
    self.signature = function_def.signature
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = scoped_fn
    self._grad_func = None
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:46,代码来源:function.py


示例11: save

  def save(self, sess, save_path, global_step=None, latest_filename=None):
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched.  The variables to
    save must also have been initialized.

    The method returns the path of the newly created checkpoint file.  This
    path can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: String.  Path to the checkpoint filename.  If the saver is
        `sharded`, this is the prefix of the sharded checkpoint filename.
      global_step: If provided the global step number is appended to
        `save_path` to create the checkpoint filename. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contains the list of most recent checkpoint filenames.  That file,
        kept in the same directory as the checkpoint files, is automatically
        managed by the saver to keep track of recent checkpoints.  Defaults to
        'checkpoint'.

    Returns:
      A string: path at which the variables were saved.  If the saver is
        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
        is the number of shards created.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components.
    """
    if latest_filename is None:
      latest_filename = "checkpoint"

    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    # Validate `sess` before it is used to evaluate a tensor `global_step`,
    # so that a bad session surfaces as the documented TypeError.
    if not isinstance(sess, session.SessionInterface):
      raise TypeError("'sess' must be a Session; %s" % sess)

    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
    else:
      checkpoint_file = save_path
    # The checkpoint-state file (`latest_filename`) lives next to the
    # checkpoint files.
    save_dir = os.path.dirname(save_path)

    model_checkpoint_path = sess.run(
        self._save_tensor_name, {self._filename_tensor_name: checkpoint_file})
    model_checkpoint_path = compat.as_str(model_checkpoint_path)
    self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
    update_checkpoint_state(save_dir, model_checkpoint_path,
                            self.last_checkpoints, latest_filename)
    return model_checkpoint_path
开发者ID:hessenh,项目名称:Human-Activity-Recognition,代码行数:55,代码来源:saver.py


示例12: __init__

  def __init__(self, name, graph, operations, inputs, outputs):
    """Builds a C function from part of `graph` and records its FunctionDef.

    Non-OK statuses from the C calls are raised as exceptions by
    `errors.raise_exception_on_not_ok_status()`. The C function handle is
    stored unwrapped in `self._c_func`. When executing eagerly it is also
    passed to `_register`.

    Args:
      name: str, the name for the created function.
      graph: Graph, the graph containing the operations in the function
      operations: list of Operation; the subset of operations in the graph
        which will be in the function
      inputs: the tensors in the graph to be used as inputs to the function
      outputs: the tensors in the graph which will be outputs to the function
    """
    # pylint: disable=protected-access
    c_graph = graph._c_graph
    fn_name = compat.as_str(name)
    c_ops = [op._c_op for op in operations]
    c_inputs = [tensor._as_tf_output() for tensor in inputs]
    c_outputs = [tensor._as_tf_output() for tensor in outputs]
    # pylint: enable=protected-access
    with errors.raise_exception_on_not_ok_status() as status:
      fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
          c_graph, fn_name, False, c_ops, c_inputs, c_outputs, [], None,
          compat.as_str(""), status)

    # TODO(apassos) avoid creating a FunctionDef (specially to grab the
    # signature, but also in general it's nice not to depend on it.
    with c_api_util.tf_buffer() as buffer_:
      with errors.raise_exception_on_not_ok_status() as status:
        pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
      proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
    function_def = function_pb2.FunctionDef()
    function_def.ParseFromString(compat.as_bytes(proto_data))

    if context.executing_eagerly():
      _register(fn)
    signature = function_def.signature
    self.definition = function_def
    self.name = signature.name
    self.signature = signature
    self.grad_func_name = None
    self.python_grad_func = None
    self._c_func = fn
    self._grad_func = None
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:40,代码来源:function.py


示例13: request_stop

  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.
    """
    with self._lock:
      if not self._stop_event.is_set():
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            # `ex` is already an exc_info triple; log its exception value.
            # as_str_any works on Python 2 and 3 (`unicode` is Py2-only).
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str_any(ex[1]))
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str_any(ex))
            exc_info = sys.exc_info()
            if exc_info[1] is None:
              # Called outside an `except` block, so sys.exc_info() is empty.
              # Build the triple from `ex` itself instead of losing it.
              exc_info = (type(ex), ex, getattr(ex, "__traceback__", None))
            self._exc_info_to_raise = exc_info
        self._stop_event.set()
开发者ID:peace195,项目名称:tensorflow,代码行数:22,代码来源:coordinator.py


示例14: _set_c_attrs

  def _set_c_attrs(self, attrs):
    """Sets `attrs` as attributes of self._c_func.

    Requires that self._c_func is not None.

    Args:
      attrs: a dictionary from attribute name to attribute proto value
    """
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
                                         serialized)
开发者ID:didukhle,项目名称:tensorflow,代码行数:14,代码来源:function.py


示例15: _ReadAndCheckRowsUsingFeatures

  def _ReadAndCheckRowsUsingFeatures(self, num_rows):
    """Reads `num_rows` rows through `parse_example` and checks each one.

    The fake server is told to serve `num_rows` rows. For each row:
    * "int64_col" selects the expected entry in `_ROWS` and must match its
      first field.
    * "string_col" must be "s_<int64_col>" when that entry's second field is
      truthy, and the "s_default" default otherwise.
    All row ids 0..num_rows-1 must be seen. One further read must then fail
    because the queue is closed and empty.
    """
    self.server.handler.num_rows = num_rows

    with self.test_session() as sess:
      int_feature = parsing_ops.FixedLenFeature([1], dtype=dtypes.int64)
      str_feature = parsing_ops.FixedLenFeature(
          [1], dtype=dtypes.string, default_value="s_default")
      feature_configs = {"int64_col": int_feature, "string_col": str_feature}
      address = self.server.httpd.server_address
      reader = cloud.BigQueryReader(
          project_id=_PROJECT,
          dataset_id=_DATASET,
          table_id=_TABLE,
          num_partitions=4,
          features=feature_configs,
          timestamp_millis=1,
          test_end_point=("%s:%s" % (address[0], address[1])))

      key, value = _SetUpQueue(reader)

      seen_rows = []
      features = parsing_ops.parse_example(
          array_ops.reshape(value, [1]), feature_configs)
      fetches = [features["int64_col"], features["string_col"]]
      for _ in range(num_rows):
        int_value, str_value = sess.run(fetches)

        # Each fetch is a batch holding exactly one single-element example.
        self.assertEqual(int_value.shape, (1, 1))
        self.assertEqual(str_value.shape, (1, 1))
        row_id = int_value[0][0]
        seen_rows.append(row_id)

        expected_row = _ROWS[row_id]
        self.assertEqual(row_id, expected_row[0])
        if expected_row[1]:
          expected_str = "s_%d" % row_id
        else:
          expected_str = "s_default"
        self.assertEqual(compat.as_str(str_value[0][0]), expected_str)

      self.assertItemsEqual(seen_rows, range(num_rows))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])
开发者ID:brainwy12,项目名称:tensorflow,代码行数:50,代码来源:bigquery_reader_ops_test.py


示例16: _node_def

def _node_def(from_node_def, export_scope, unbound_inputs):
  """Create a `NodeDef` proto with export_scope stripped.

  Works on a deep copy of `from_node_def`:
  * Inputs outside `export_scope` get "$unbound_inputs_" inserted after any
    leading '^' and are appended to `unbound_inputs`. All other inputs have
    `export_scope` stripped.
  * The node name and the "_class" entries are stripped. When `export_scope`
    is set, only "_class" entries inside it are kept.
  * Every other attr is copied unchanged.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: A list to which the names of unbound inputs are appended.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)
  for idx, input_name in enumerate(node_def.input):
    is_unbound = (export_scope and
                  not input_name.lstrip("^").startswith(export_scope))
    if not is_unbound:
      node_def.input[idx] = ops.strip_name_scope(input_name, export_scope)
      continue
    # Prefix unbound names so they are easy to identify later.
    node_def.input[idx] = re.sub(r"([\^]|^)(.*)", r"\1$unbound_inputs_\2",
                                 compat.as_str(input_name))
    unbound_inputs.append(node_def.input[idx])

  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))
  for attr_key, attr_value in six.iteritems(from_node_def.attr):
    if attr_key != "_class":
      node_def.attr[attr_key].CopyFrom(attr_value)
      continue
    kept = [compat.as_bytes(ops.strip_name_scope(entry, export_scope))
            for entry in attr_value.list.s
            if not export_scope or
            compat.as_str(entry).split("@")[1].startswith(export_scope)]
    node_def.attr[attr_key].CopyFrom(attr_value_pb2.AttrValue(
        list=attr_value_pb2.AttrValue.ListValue(s=kept)))

  return node_def
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:36,代码来源:meta_graph.py


示例17: _GetColocationNames

def _GetColocationNames(op):
  """Returns names of the ops that `op` should be colocated with."""
  colocation_names = []
  try:
    class_values = op.get_attr('_class')
  except ValueError:
    # No _class attr
    return
  for val in class_values:
    val = compat.as_str(val)
    if val.startswith('loc:@'):
      colocation_node_name = val[len('loc:@'):]
      if colocation_node_name != op.name:
        colocation_names.append(colocation_node_name)
  return colocation_names
开发者ID:andrewharp,项目名称:tensorflow,代码行数:15,代码来源:importer.py


示例18: lookup

 def lookup(self, name):
     """Returns the object registered under `name`.

     Args:
       name: registry key; bytes or text, normalized with `compat.as_str`.
     Returns:
       The `_TYPE_TAG` field of the registration record for `name`.
     Raises:
       LookupError: if "name" has not been registered.
     """
     key = compat.as_str(name)
     if key not in self._registry:
         raise LookupError(
             "%s registry has no entry for: %s" % (self._name, key))
     return self._registry[key][_TYPE_TAG]
开发者ID:chengyang317,项目名称:information_pursuit,代码行数:15,代码来源:registry.py


示例19: canonicalize_signatures

def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: `None`, a single function, or a mapping from signature keys
      to functions. A single function is stored under
      `DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A dict from signature key to a concrete function that returns a dict of
    Tensors. The function accepts one positional argument when it has exactly
    one input; otherwise its arguments are keyword-only.

  Raises:
    ValueError: If a value cannot be turned into a signature function.
  """
  # `collections.Mapping` was removed in Python 3.10; the ABC lives in
  # `collections.abc` (available since Python 3.3).
  from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top

  if signatures is None:
    return {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
  concrete_signatures = {}
  for signature_key, function in signatures.items():
    signature_function = _get_signature(function)
    if signature_function is None:
      raise ValueError(
          ("Expected a TensorFlow function to generate a signature for, but "
           "got {}. Only `tf.functions` with an input signature or "
           "concrete functions can be used as a signature.").format(function))

    # Re-wrap the function so that it returns a dictionary of Tensors. This
    # matches the format of 1.x-style signatures.
    # pylint: disable=cell-var-from-loop
    @def_function.function
    def signature_wrapper(**kwargs):
      structured_outputs = signature_function(**kwargs)
      return _normalize_outputs(
          structured_outputs, signature_function.name, signature_key)
    # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names
    # always match keyword arguments.
    tensor_spec_signature = {}
    for keyword, tensor in zip(
        signature_function._arg_keywords,  # pylint: disable=protected-access
        signature_function.inputs):
      keyword = compat.as_str(keyword)
      tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
          tensor, name=keyword)
    # Tracing happens here, within the same loop iteration, so the closure
    # above sees this iteration's `signature_function` / `signature_key`.
    final_concrete = signature_wrapper.get_concrete_function(
        **tensor_spec_signature)
    # pylint: disable=protected-access
    if len(final_concrete._arg_keywords) == 1:
      # If there is only one input to the signature, a very common case, then
      # ordering is unambiguous and we can let people pass a positional
      # argument. Since SignatureDefs are unordered (protobuf "map") multiple
      # arguments means we need to be keyword-only.
      final_concrete._num_positional_args = 1
    else:
      final_concrete._num_positional_args = 0
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete
    # pylint: enable=cell-var-from-loop
  return concrete_signatures
开发者ID:aritratony,项目名称:tensorflow,代码行数:48,代码来源:signature_serialization.py


示例20: initialize_tpu_system

def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices in a separate session and graph.

  * Graph mode: runs `tpu.initialize_system()` in a fresh Graph, using a
    Session that targets the resolver's master with soft placement allowed.
  * Eager mode: the initialization op is wrapped in a defun and executed
    through `TPUPartitionedCall`, so that the rewrite passes run on it.

  Args:
    cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster. Defaults to
        `TPUClusterResolver("")`.
  Returns:
    The tf.contrib.tpu.Topology object for the topology of the TPU cluster.
  """
  resolver = cluster_resolver
  if resolver is None:
    resolver = TPUClusterResolver("")
  master = resolver.master()

  logging.info("Initializing the TPU system.")

  if not context.executing_eagerly():
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default(), session_lib.Session(
        config=session_config, target=master) as sess:
      serialized_topology = sess.run(tpu.initialize_system())
  else:
    # tpu.initialize_system only creates a dummy op that triggers
    # DistributedTPURewritePass; that pass inserts the real initialization
    # ops. Running it eagerly would therefore do nothing. Instead, wrap it
    # in a defun and run that defun through TPUPartitionedCallOp, which is
    # the easiest way to trigger the rewrite.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    # The defun is never called directly (it only holds the dummy op). It is
    # traced just so it is added to the eager context and gets a name.
    # pylint: disable=protected-access
    concrete = _tpu_init_fn._get_concrete_function_internal()
    func_name = compat.as_str(concrete._inference_function.name)
    # pylint: enable=protected-access

    outputs = tpu_functional_ops.TPUPartitionedCall(
        args=[], device_ordinal=0, Tout=[dtypes.string], f=func_name)
    serialized_topology = outputs[0].numpy()

  logging.info("Finished initializing TPU system.")
  return topology.Topology(serialized=serialized_topology)
开发者ID:jackd,项目名称:tensorflow,代码行数:46,代码来源:tpu_strategy.py



注:本文中的tensorflow.python.util.compat.as_str函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python compat.as_str_any函数代码示例（发布时间：2022-05-27）
下一篇:
Python compat.as_bytes函数代码示例（发布时间：2022-05-27）
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap