Python ops.add_to_collections function code examples

This article collects typical usage examples of the Python function tensorflow.python.framework.ops.add_to_collections. If you want to know how add_to_collections is used in practice, the curated examples below show it in real TensorFlow source files.



Below are 20 code examples of add_to_collections, sorted by popularity by default.
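To make the behaviour behind all of these examples concrete: ops.add_to_collections(names, value) forwards to the default graph's add_to_collections method. TensorFlow documents that a string `names` is treated as a single collection name, that duplicate names are ignored, and that existing membership of `value` is not checked. The sketch below models only that bookkeeping in plain Python so it runs without TensorFlow; `ToyGraph` is a hypothetical class, not a TensorFlow API.

```python
from collections import defaultdict


class ToyGraph:
    """Hypothetical stand-in for tf.Graph's collection bookkeeping (not TensorFlow code)."""

    def __init__(self):
        # Each collection name maps to a plain list; lists allow duplicates.
        self._collections = defaultdict(list)

    def add_to_collection(self, name, value):
        self._collections[name].append(value)

    def add_to_collections(self, names, value):
        # A bare string is one collection name, not an iterable of characters;
        # any other iterable is de-duplicated before adding.
        names = (names,) if isinstance(names, str) else set(names)
        for name in names:
            self.add_to_collection(name, value)

    def get_collection_ref(self, name):
        # Returns the live list, so callers can mutate it in place.
        return self._collections[name]

    def get_collection(self, name):
        # Returns a copy, so callers cannot corrupt the stored list.
        return list(self._collections.get(name, []))


g = ToyGraph()
g.add_to_collections(["variables", "variables", "trainable"], "w")
g.add_to_collections("losses", "l2")
g.add_to_collections(["variables"], "w")  # no membership check: "w" is added again
print(g.get_collection("variables"))  # ['w', 'w']
print(g.get_collection("losses"))     # ['l2']
```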

Example 1: var_creator

      def var_creator(*args, **kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
        if collections is None:
          collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        kwargs["collections"] = []

        # Create and wrap the variable.
        v = next_creator(*args, **kwargs)
        wrapped = values.AggregatingVariable(v, aggregation)

        # Add the wrapped variable to the requested collections.
        # The handling of eager mode and the global step matches
        # ResourceVariable._init_from_args().
        if not context.executing_eagerly():
          g = ops.get_default_graph()
          # If "trainable" is True, next_creator() will add the contained
          # variable to the TRAINABLE_VARIABLES collection, so we manually
          # remove it and replace with the wrapper. We can't set "trainable"
          # to False for next_creator() since that causes functions like
          # implicit_gradients to skip those variables.
          if kwargs.get("trainable", True):
            collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
            l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
            l.remove(v)
          g.add_to_collections(collections, wrapped)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, wrapped)

        return wrapped
Developer ID: JonathanRaiman, Project: tensorflow, Lines of code: 31, Source: parameter_server_strategy.py
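Example 1 relies on get_collection_ref returning the live list, so the raw variable can be removed in place and replaced by its wrapper. A minimal sketch of that swap, assuming a plain dict instead of a graph; `Wrapper`, `inner_creator` and `wrapping_creator` are hypothetical stand-ins for AggregatingVariable, next_creator and var_creator, and the eager-mode branch is omitted:

```python
# Hypothetical, TensorFlow-free model of Example 1's collection swap.
GLOBAL_VARIABLES = "variables"               # stand-in for GraphKeys.GLOBAL_VARIABLES
TRAINABLE_VARIABLES = "trainable_variables"  # stand-in for GraphKeys.TRAINABLE_VARIABLES

collections_store = {}  # collection name -> list of values


def add_to_collections(names, value):
    names = (names,) if isinstance(names, str) else set(names)
    for name in names:
        collections_store.setdefault(name, []).append(value)


class Wrapper:
    """Stand-in for AggregatingVariable: it only holds the inner variable."""

    def __init__(self, inner):
        self.inner = inner


def inner_creator(name, trainable=True):
    # Like next_creator(): the raw variable lands in the trainable collection.
    if trainable:
        add_to_collections(TRAINABLE_VARIABLES, name)
    return name  # a string stands in for the raw variable


def wrapping_creator(name, trainable=True, collections=None):
    if collections is None:
        collections = [GLOBAL_VARIABLES]
    v = inner_creator(name, trainable=trainable)
    wrapped = Wrapper(v)
    if trainable:
        collections = list(collections) + [TRAINABLE_VARIABLES]
        # Remove the raw variable from the live list so only the wrapper remains.
        collections_store[TRAINABLE_VARIABLES].remove(v)
    add_to_collections(collections, wrapped)
    return wrapped


w = wrapping_creator("kernel")
print([type(x).__name__ for x in collections_store[TRAINABLE_VARIABLES]])  # ['Wrapper']
```

Building a new list with `list(collections) + [...]` instead of calling `append` avoids mutating a list the caller may reuse; Example 1 itself appends in place.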


Example 2: variable_creator_scope

  def variable_creator_scope(self, next_creator, **kwargs):
    """Creates variables & adds them to collections to match legacy code."""
    collections = kwargs.pop("collections", None)
    v = None

    # Get expected variable name.
    name = kwargs.get("name", None)
    with ops.name_scope(name, "Variable") as name_scope:
      name = name_scope

    if self._share_variables:
      v = self._variables_by_name.get(name, None)

    if v is None:
      v = next_creator(**kwargs)
      self._variables.append(v)
      if self._share_variables:
        self._variables_by_name[name] = v

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if v.trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]

    ops.add_to_collections(collections, v)

    return v
Developer ID: perfmjs, Project: tensorflow, Lines of code: 27, Source: wrap_function.py
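In Example 2, ops.add_to_collections runs even when a shared variable is reused, and since collections are plain lists with no membership check, the reused variable is appended a second time. The toy model below shows this, assuming the name has already been resolved (the ops.name_scope step is omitted); `Scope` and the dict used as a variable are hypothetical.

```python
# Hypothetical model of Example 2's sharing logic (no TensorFlow needed).
GLOBAL_VARIABLES = "variables"               # stand-in key
TRAINABLE_VARIABLES = "trainable_variables"  # stand-in key


class Scope:
    def __init__(self, share_variables):
        self.share_variables = share_variables
        self.variables = []
        self.by_name = {}
        self.collections = {}

    def add_to_collections(self, names, value):
        names = (names,) if isinstance(names, str) else set(names)
        for name in names:
            self.collections.setdefault(name, []).append(value)

    def create(self, name, trainable=True, collections=None):
        v = self.by_name.get(name) if self.share_variables else None
        if v is None:
            v = {"name": name, "trainable": trainable}  # stand-in variable
            self.variables.append(v)
            if self.share_variables:
                self.by_name[name] = v
        if collections is None:
            collections = [GLOBAL_VARIABLES]
        if v["trainable"] and TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [TRAINABLE_VARIABLES]
        # Runs even when v was reused, as in Example 2.
        self.add_to_collections(collections, v)
        return v


s = Scope(share_variables=True)
a = s.create("w")
b = s.create("w")
print(a is b, len(s.variables), len(s.collections[GLOBAL_VARIABLES]))  # True 1 2
```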


Example 3: __init__

  def __init__(self, initial_value, trainable=True, collections=None,
               validate_shape=True, name=None):
    """Creates a new variable with value `initial_value`.

    The new variable is added to the graph collections listed in `collections`,
    which defaults to `[GraphKeys.VARIABLES]`.

    If `trainable` is `True` the variable is also added to the graph collection
    `GraphKeys.TRAINABLE_VARIABLES`.

    This constructor creates both a `variable` Op and an `assign` Op to set the
    variable to its initial value.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
        The initial value for the Variable. Must have a shape specified unless
        `validate_shape` is set to False.
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.

    Returns:
      A Variable.

    Raises:
      ValueError: If the initial value does not have a shape and
        `validate_shape` is `True`.
    """
    if collections is None:
      collections = [ops.GraphKeys.VARIABLES]
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.control_dependencies(None):
      with ops.op_scope([initial_value], name, "Variable") as name:
        self._initial_value = ops.convert_to_tensor(initial_value,
                                                    name="initial_value")
        initial_value_shape = self._initial_value.get_shape()
        if validate_shape and not initial_value_shape.is_fully_defined():
          raise ValueError("initial_value must have a shape specified: %s"
                           % self._initial_value)
        shape_to_set = initial_value_shape if validate_shape else []
        self._variable = state_ops.variable_op(
            shape_to_set, self._initial_value.dtype.base_dtype,
            set_shape=validate_shape, name=name)
        with ops.device(self._variable.device):
          self._initializer_op = state_ops.assign(
              self._variable, self._initial_value,
              validate_shape=validate_shape).op
          self._snapshot = array_ops.identity(self._variable, name="read")

    ops.add_to_collections(collections, self)
    self._save_slice_info = None
Developer ID: Mandar-Shinde, Project: tensorflow, Lines of code: 59, Source: variables.py
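Examples 3 and 4 use the same idiom for choosing collections. A small helper isolating it, where the string constants are stand-ins for the GraphKeys values:

```python
# Hypothetical helper mirroring the collection defaults in Examples 3 and 4.
VARIABLES = "variables"                      # stand-in for GraphKeys.VARIABLES
TRAINABLE_VARIABLES = "trainable_variables"  # stand-in for GraphKeys.TRAINABLE_VARIABLES


def resolve_collections(collections=None, trainable=True):
    """Return the collection names a new variable should be added to."""
    if collections is None:
        collections = [VARIABLES]
    if trainable and TRAINABLE_VARIABLES not in collections:
        # Build a new list instead of appending, so a list passed in by the
        # caller (possibly shared across many variables) is left untouched.
        collections = list(collections) + [TRAINABLE_VARIABLES]
    return collections


shared = ["my_vars"]
print(resolve_collections(shared))           # ['my_vars', 'trainable_variables']
print(shared)                                # ['my_vars']
print(resolve_collections(trainable=False))  # ['variables']
```

Because the trainable key is added through a new list, a `collections` list shared by the caller is never modified.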


Example 4: _init_from_args

  def _init_from_args(self, initial_value=None, trainable=True,
                      collections=None, validate_shape=True,
                      caching_device=None, name=None):
    """Creates a new variable from arguments.

    Args:
      initial_value: A `Tensor`, or Python object convertible to a `Tensor`.
        The initial value for the Variable. Must have a shape specified unless
        `validate_shape` is set to False.
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.VARIABLES]`.
      validate_shape: If `False`, allows the variable to be initialized with a
        value of unknown shape. If `True`, the default, the shape of
        `initial_value` must be known.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    if collections is None:
      collections = [ops.GraphKeys.VARIABLES]
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    with ops.control_dependencies(None):
      with ops.op_scope([initial_value], name, "Variable") as name:
        self._initial_value = ops.convert_to_tensor(initial_value,
                                                    name="initial_value")
        initial_value_shape = self._initial_value.get_shape()
        if validate_shape and not initial_value_shape.is_fully_defined():
          raise ValueError("initial_value must have a shape specified: %s"
                           % self._initial_value)
        shape_to_set = initial_value_shape if validate_shape else []
        self._variable = state_ops.variable_op(
            shape_to_set, self._initial_value.dtype.base_dtype,
            set_shape=validate_shape, name=name)
        with ops.device(self._variable.device):
          self._initializer_op = state_ops.assign(
              self._variable, self._initial_value,
              validate_shape=validate_shape).op
        with ops.device(caching_device if caching_device is not None
                        else self._variable.device):
          self._snapshot = array_ops.identity(self._variable, name="read")

    ops.add_to_collections(collections, self)
    self._caching_device = caching_device
    self._save_slice_info = None
Developer ID: chintanpanchamia, Project: tensorflow, Lines of code: 58, Source: variables.py


Example 5: _register_dense_variable_read

def _register_dense_variable_read(read, collections, trainable):
  """Helper function to put a read from a dense variable in the collections."""
  if collections is None:
    collections = []
  if (trainable and
      ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES not in collections):
    collections = (list(collections) +
                   [ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES])
  ops.add_to_collections(collections, read)
Developer ID: brchiu, Project: tensorflow, Lines of code: 9, Source: resource_variable_ops.py
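A plain-Python sketch of the helper in Example 5, written so that add_to_collections runs for every read rather than only for trainable ones (an assumption about the intended control flow, based on the docstring). The collection key string and the read values are stand-ins:

```python
# Hypothetical model of Example 5's helper; strings stand in for read tensors.
TRAINABLE_RESOURCE_VARIABLES = "trainable_resource_variables"  # stand-in key

store = {}  # collection name -> list of values


def add_to_collections(names, value):
    names = (names,) if isinstance(names, str) else set(names)
    for name in names:
        store.setdefault(name, []).append(value)


def register_dense_variable_read(read, collections, trainable):
    if collections is None:
        collections = []  # unlike variables, reads get no default collection
    if trainable and TRAINABLE_RESOURCE_VARIABLES not in collections:
        collections = list(collections) + [TRAINABLE_RESOURCE_VARIABLES]
    # Register the read in every requested collection, trainable or not.
    add_to_collections(collections, read)


register_dense_variable_read("read_a", None, trainable=True)
register_dense_variable_read("read_b", ["my_reads"], trainable=False)
register_dense_variable_read("read_c", None, trainable=False)  # no-op
print(store)  # {'trainable_resource_variables': ['read_a'], 'my_reads': ['read_b']}
```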


Example 6: _init_from_args

  def _init_from_args(self, name):
    """Initialize the CriticalSection from constructor arguments."""
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.control_dependencies(None):
        # pylint: disable=protected-access
        handle_name = ops._name_from_scope_name(name)
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if container is None:
          container = ""
        self._handle = gen_resource_variable_ops.critical_section_op(
            shared_name=handle_name, name=name)
    if context.in_graph_mode():
      ops.add_to_collections(CRITICAL_SECTIONS, self)
Developer ID: ChengYuXiang, Project: tensorflow, Lines of code: 14, Source: critical_section_ops.py


Example 7: variable_creator_scope

  def variable_creator_scope(self, next_creator, **kwargs):
    """Creates variables & adds them to collections to match legacy code."""
    v = next_creator(**kwargs)
    self._variables.append(v)

    collections = kwargs.get("collections")
    trainable = v.trainable

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]

    ops.add_to_collections(collections, v)

    return v
Developer ID: kylin9872, Project: tensorflow, Lines of code: 16, Source: wrap_function.py


Example 8: _init_from_args

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the CriticalSection from constructor arguments."""
    with ops.name_scope(name, "CriticalSection", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          shared_name = name
        if container is None:
          container = ""
        self._handle = gen_resource_variable_ops.mutex_v2(
            shared_name=shared_name, container=container, name=name)

    if not context.executing_eagerly():
      ops.add_to_collections(CRITICAL_SECTIONS, self)
Developer ID: DILASSS, Project: tensorflow, Lines of code: 16, Source: critical_section_ops.py
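Examples 6, 8 and 9 share a pattern: the object registers itself under a module-level collection key only when not executing eagerly, since TensorFlow documents collections as unsupported under eager execution. A TensorFlow-free sketch of Example 8's rules, where `FakeCriticalSection`, the `executing_eagerly` flag and the key string are hypothetical:

```python
# Hypothetical model of Example 8's registration rules (no TensorFlow).
CRITICAL_SECTIONS = "critical_sections"  # stand-in collection key

registry = {}  # collection name -> list of registered objects


class FakeCriticalSection:
    def __init__(self, name, shared_name=None, executing_eagerly=False):
        # Fall back to the (scoped) name when no shared_name is given.
        self.shared_name = name if shared_name is None else shared_name
        # Only register in graph mode, mirroring `if not executing_eagerly()`.
        if not executing_eagerly:
            registry.setdefault(CRITICAL_SECTIONS, []).append(self)


cs_graph = FakeCriticalSection("lock")
cs_eager = FakeCriticalSection("lock", shared_name="other", executing_eagerly=True)
print(cs_graph.shared_name, cs_eager.shared_name)  # lock other
print(len(registry[CRITICAL_SECTIONS]))            # 1
```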


Example 9: _init_from_args

  def _init_from_args(self, name, shared_name):  # pylint: disable=invalid-name
    """Initialize the Notification from constructor arguments."""
    with ops.name_scope(name, "Notification", []) as name:
      with ops.init_scope():
        # pylint: disable=protected-access
        container = ops.get_default_graph()._container
        # pylint: enable=protected-access
        if shared_name is None:
          shared_name = name
        if container is None:
          container = ""
        # Build the notification resource outside of any control dependencies.
        with ops.control_dependencies(None):
          self._handle = gen_resource_variable_ops.notification(
              shared_name=shared_name, container=container, name=name)

    if not context.executing_eagerly():
      ops.add_to_collections(NOTIFICATIONS, self)
Developer ID: ebrevdo, Project: tensorflow, Lines of code: 18, Source: notification_ops.py


Example 10: collect_named_outputs

def collect_named_outputs(collections, alias, outputs):
  """Add `Tensor` outputs tagged with alias to collections.

  It is useful to collect end-points or tags for summaries. Example of usage:

  logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)
  assert 'inception_v3/logits' in logits.aliases

  Args:
    collections: A collection or list of collections. If None skip collection.
    alias: String to append to the list of aliases of outputs, for example,
           'inception_v3/conv1'.
    outputs: Tensor, an output tensor to collect

  Returns:
    The outputs Tensor to allow inline call.
  """
  append_tensor_alias(outputs, alias)
  if collections:
    ops.add_to_collections(collections, outputs)
  return outputs
Developer ID: AliMiraftab, Project: tensorflow, Lines of code: 21, Source: utils.py
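Example 10 depends on append_tensor_alias, whose body is not shown here. The sketch below assumes it strips a trailing '/' and appends the alias to an `aliases` list on the tensor; `FakeTensor` and `append_alias` are hypothetical stand-ins, and a dict replaces the graph's collections:

```python
# Hypothetical model of Example 10; FakeTensor and append_alias are stand-ins.
class FakeTensor:
    def __init__(self, name):
        self.name = name


store = {}  # collection name -> list of tensors


def append_alias(tensor, alias):
    # Assumed behaviour: strip one trailing '/' and keep every alias in a list.
    if alias.endswith("/"):
        alias = alias[:-1]
    if not hasattr(tensor, "aliases"):
        tensor.aliases = []
    tensor.aliases.append(alias)
    return tensor


def collect_named_outputs(collections, alias, outputs):
    append_alias(outputs, alias)
    if collections:
        names = (collections,) if isinstance(collections, str) else set(collections)
        for name in names:
            store.setdefault(name, []).append(outputs)
    return outputs


logits = collect_named_outputs("end_points", "inception_v3/logits/", FakeTensor("logits:0"))
collect_named_outputs(None, "final", logits)  # alias only, no collection
print(logits.aliases)  # ['inception_v3/logits', 'final']
```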


Example 11: collect_named_outputs

def collect_named_outputs(collections, name, outputs):
  """Add tuple (name, outputs) to collections.

  It is useful to collect end-points or tags for summaries. Example of usage:

  logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)

  Args:
    collections: A collection or list of collections. If None skip collection.
    name: String, name to represent the outputs, ex. 'inception_v3/conv1'
    outputs: Tensor, an output tensor to collect

  Returns:
    The outputs Tensor to allow inline call.
  """
  if collections:
    # Remove ending '/' if present.
    if name[-1] == '/':
      name = name[:-1]
    ops.add_to_collections(collections, (name, outputs))
  return outputs
Developer ID: AI-MR-Related, Project: tensorflow, Lines of code: 21, Source: utils.py
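Because Example 11 stores (name, outputs) tuples rather than bare tensors, a collection can later be turned back into an end-points dictionary. A hypothetical sketch of that round trip, with strings standing in for tensors and `collection_to_dict` as an illustrative helper:

```python
# Hypothetical sketch of Example 11: the collection stores (name, tensor)
# pairs, so it can be turned back into an end-points dict. Strings stand in
# for tensors.
store = {}  # collection name -> list of (name, outputs) pairs


def collect_named_outputs(collections, name, outputs):
    if collections:
        # endswith() also copes with an empty name, unlike name[-1].
        if name.endswith("/"):
            name = name[:-1]
        keys = (collections,) if isinstance(collections, str) else set(collections)
        for key in keys:
            store.setdefault(key, []).append((name, outputs))
    return outputs


def collection_to_dict(key):
    # dict() keeps insertion order; a repeated name keeps its last value.
    return dict(store.get(key, []))


collect_named_outputs("end_points", "inception_v3/conv1/", "conv1_tensor")
collect_named_outputs("end_points", "inception_v3/logits", "logits_tensor")
print(collection_to_dict("end_points"))
# {'inception_v3/conv1': 'conv1_tensor', 'inception_v3/logits': 'logits_tensor'}
```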


Example 12: collect_named_outputs

def collect_named_outputs(collections, alias, outputs):
  """Add `Tensor` outputs tagged with alias to collections.

  It is useful to collect end-points or tags for summaries. Example of usage:

  logits = collect_named_outputs('end_points', 'inception_v3/logits', logits)
  assert logits.alias == 'inception_v3/logits'

  Args:
    collections: A collection or list of collections. If None skip collection.
    alias: String, alias to name the outputs, ex. 'inception_v3/conv1'
    outputs: Tensor, an output tensor to collect

  Returns:
    The outputs Tensor to allow inline call.
  """
  # Remove ending '/' if present.
  if alias[-1] == '/':
    alias = alias[:-1]
  outputs.alias = alias
  if collections:
    ops.add_to_collections(collections, outputs)
  return outputs
Developer ID: DavidNemeskey, Project: tensorflow, Lines of code: 23, Source: utils.py


Example 13: _init_from_args


#......... part of the code is omitted here .........
        shape and `validate_shape` is `True`.
    """
        if initial_value is None:
            raise ValueError("initial_value must be specified.")
        init_from_fn = callable(initial_value)
        if init_from_fn and dtype is None:
            raise ValueError("dtype must also be specified when initial_value is callable.")

        if collections is None:
            collections = [ops.GraphKeys.GLOBAL_VARIABLES]
        if not isinstance(collections, (list, tuple, set)):
            raise ValueError(
                "collections argument to Variable constructor must be a list, tuple, "
                "or set. Got %s of type %s" % (collections, type(collections))
            )
        if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
            collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
        expected_shape = tensor_shape.as_shape(expected_shape)
        with ops.control_dependencies(None):
            with ops.name_scope(name, "Variable", [] if init_from_fn else [initial_value]) as name:

                # Get the initial value from a callable function. The real shape of the
                # variable will be set later, since under the init_from_fn case, the
                # shape won't be known until after the function is invoked.
                #
                # NOTE: The current Variable OpKernel does not support
                # partially defined shapes, so we only set the shape if it is
                # fully defined. For historical reasons, we use the scalar
                # shape (`[]`) to represent an unknown or partially known
                # shape. A future version of the Variable ops will remove this
                # limitation.
                def full_shape_to_list(shape):
                    """Returns shape as a list if shape is fully defined."""
                    if shape and shape.is_fully_defined():
                        return shape.as_list()
                    else:
                        return []

                def assert_expected_shape():
                    """Asserts that the initial value has the expected shape."""
                    if expected_shape:
                        expected_shape.assert_is_compatible_with(self._initial_value.get_shape())

                if init_from_fn:
                    expected_shape_list = full_shape_to_list(expected_shape)
                    set_shape = validate_shape and expected_shape.is_fully_defined()
                    self._variable = state_ops.variable_op(
                        expected_shape_list, dtype.base_dtype, set_shape=set_shape, name=name
                    )
                    with ops.colocate_with(self._variable.op):
                        with ops.name_scope("Initializer"):
                            # Colocate the tensors created by the initial_value() function
                            # with the variable itself.
                            self._initial_value = ops.convert_to_tensor(
                                initial_value(), name="initial_value", dtype=dtype
                            )
                            assert_expected_shape()

                # Or get the initial value from a Tensor or Python object.
                else:
                    self._initial_value = ops.convert_to_tensor(initial_value, name="initial_value", dtype=dtype)
                    assert_expected_shape()
                    set_shape = validate_shape and self._initial_value.get_shape().is_fully_defined()
                    # In this case, the variable op can't be created until after the
                    # initial_value has been converted to a Tensor with a known type.
                    self._variable = state_ops.variable_op(
                        full_shape_to_list(self._initial_value.get_shape()),
                        self._initial_value.dtype.base_dtype,
                        set_shape=set_shape,
                        name=name,
                    )

                # Manually overrides the variable's shape with the initial value's.
                if validate_shape:
                    initial_value_shape = self._initial_value.get_shape()
                    if not initial_value_shape.is_fully_defined():
                        raise ValueError("initial_value must have a shape specified: %s" % self._initial_value)
                    self._variable.set_shape(initial_value_shape)
                    # TODO(b/28152992): Remove the below hack modifying the node_def shape
                    # directly once set_shape() handles it.
                    self._variable.op.node_def.attr["shape"].shape.CopyFrom(initial_value_shape.as_proto())

                # Assigns initial value.
                self._initializer_op = state_ops.assign(
                    self._variable, self._initial_value, validate_shape=validate_shape
                ).op

                # TODO(vrv): Change this class to not take caching_device, but
                # to take the op to colocate the snapshot with, so we can use
                # colocation rather than devices.
                if caching_device is not None:
                    with ops.device(caching_device):
                        self._snapshot = array_ops.identity(self._variable, name="read")
                else:
                    with ops.colocate_with(self._variable.op):
                        self._snapshot = array_ops.identity(self._variable, name="read")

        ops.add_to_collections(collections, self)
        self._caching_device = caching_device
        self._save_slice_info = None
Developer ID: shakamunyi, Project: tensorflow, Lines of code: 101, Source: variables.py
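The type check at the top of Example 13 protects the `list(collections) + [...]` step: a bare string is iterable, so without the check it would be split into one-character collection names (and `not in` would perform a substring test). A stand-alone sketch of the check, with stand-in key strings and a hypothetical `checked_collections` helper:

```python
# Hypothetical sketch of the type check in Example 13; key strings are stand-ins.
GLOBAL_VARIABLES = "variables"
TRAINABLE_VARIABLES = "trainable_variables"


def checked_collections(collections=None, trainable=True):
    if collections is None:
        collections = [GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
        raise ValueError(
            "collections must be a list, tuple, or set. Got %r of type %s"
            % (collections, type(collections)))
    if trainable and TRAINABLE_VARIABLES not in collections:
        collections = list(collections) + [TRAINABLE_VARIABLES]
    return collections


# What the unchecked code path would do with a bare string:
print(list("my_vars") + [TRAINABLE_VARIABLES])
# ['m', 'y', '_', 'v', 'a', 'r', 's', 'trainable_variables']

try:
    checked_collections("my_vars")
except ValueError as err:
    print("rejected:", err)
```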


Example 14: _init_from_args


#......... part of the code is omitted here .........
            with ops.name_scope("Initializer"):
              initial_value = ops.convert_to_tensor(
                  initial_value, name="initial_value", dtype=dtype)
            self._handle = _eager_safe_variable_handle(
                shape=initial_value.get_shape(),
                dtype=initial_value.dtype.base_dtype,
                shared_name=handle_name,
                name=name,
                graph_mode=False)
            self._handle_device = (
                self._handle.device if self._in_graph_mode else
                context.get_default_context().device_name)
            self._shape = initial_value.get_shape()
        # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          with ops.name_scope("Initializer"):
            initial_value = ops.convert_to_tensor(
                initial_value, name="initial_value", dtype=dtype)
          # pylint: disable=protected-access
          if (self._in_graph_mode and initial_value is not None and
              initial_value.op._get_control_flow_context() is not None):
            raise ValueError(
                "Initializer for variable %s is from inside a control-flow "
                "construct, such as a loop or conditional. When creating a "
                "variable inside a loop or conditional, use a lambda as the "
                "initializer." % name)
          # pylint: enable=protected-access
          self._handle = _eager_safe_variable_handle(
              shape=initial_value.get_shape(),
              dtype=initial_value.dtype.base_dtype,
              shared_name=handle_name,
              name=name,
              graph_mode=self._in_graph_mode)
          self._handle_device = (self._handle.device if self._in_graph_mode else
                                 context.get_default_context().device_name)
          self._shape = initial_value.get_shape()

        self._initial_value = initial_value if self._in_graph_mode else None
        self._handle_name = handle_name + ":0"
        self._dtype = initial_value.dtype.base_dtype
        self._constraint = constraint

        if self._in_graph_mode:
          with ops.name_scope("IsInitialized"):
            self._is_initialized_op = (
                gen_resource_variable_ops.var_is_initialized_op(self._handle))
          if initial_value is not None:
            with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
              self._initializer_op = (
                  gen_resource_variable_ops.assign_variable_op(
                      self._handle,
                      self._try_guard_against_uninitialized_dependencies(
                          initial_value),
                      name=n))
          with ops.name_scope("Read"), ops.colocate_with(self._handle):
            # Manually assign reads to the handle's device to avoid log
            # messages.
            with ops.device(self._handle_device):
              value = self._read_variable_op()
            self._graph_element = value
            if caching_device is not None:
              # Variables may be created in a tf.device() or ops.colocate_with()
              # context. At the same time, users would expect caching device to
              # be independent of this context, and/or would not expect the
              # current device context to be merged with the caching device
              # spec.  Therefore we reset the colocation stack before creating
              # the cached value. Note that resetting the colocation stack will
              # also reset the device stack.
              with ops.colocate_with(None, ignore_existing=True):
                with ops.device(caching_device):
                  self._cached_value = array_ops.identity(value)
            else:
              self._cached_value = None
        else:
          gen_resource_variable_ops.assign_variable_op(self._handle,
                                                       initial_value)
          self._is_initialized_op = None
          self._initializer_op = None
          self._graph_element = None
          if caching_device:
            with ops.device(caching_device):
              self._cached_value = self._read_variable_op()
          else:
            self._cached_value = None
        if context.in_graph_mode():
          ops.add_to_collections(collections, self)
        elif ops.GraphKeys.GLOBAL_STEP in collections:
          ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self)

    if not self._in_graph_mode:
      # After the handle has been created, set up a way to clean it up when
      # executing eagerly. We'll hold the only reference to the deleter, so that
      # when this object is garbage collected the deleter will be too. This
      # means ResourceVariables can be part of reference cycles without those
      # cycles being uncollectable, and means that no __del__ will be defined at
      # all in graph mode.
      self._handle_deleter = EagerResourceDeleter(
          handle=self._handle, handle_device=self._handle_device)
Developer ID: keithc61, Project: tensorflow, Lines of code: 101, Source: resource_variable_ops.py
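The end of Example 14 registers the variable in all requested collections only in graph mode; when executing eagerly, it is added to GLOBAL_STEP alone, and only if that key was requested. A hypothetical sketch of the branch, using stand-in key strings and one plain dict per mode:

```python
# Hypothetical model of the registration branch at the end of Example 14.
GLOBAL_VARIABLES = "variables"  # stand-in for GraphKeys.GLOBAL_VARIABLES
GLOBAL_STEP = "global_step"     # stand-in for GraphKeys.GLOBAL_STEP


def register_variable(store, var, collections, in_graph_mode):
    if in_graph_mode:
        names = set(collections)
    elif GLOBAL_STEP in collections:
        names = {GLOBAL_STEP}  # eager: only the global step is still recorded
    else:
        names = set()
    for name in names:
        store.setdefault(name, []).append(var)


graph_store, eager_store = {}, {}
register_variable(graph_store, "step", [GLOBAL_VARIABLES, GLOBAL_STEP], True)
register_variable(eager_store, "step", [GLOBAL_VARIABLES, GLOBAL_STEP], False)
register_variable(eager_store, "w", [GLOBAL_VARIABLES], False)
print(sorted(graph_store), sorted(eager_store))  # ['global_step', 'variables'] ['global_step']
```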


Example 15: _init_from_args


#......... part of the code is omitted here .........
        which is the initial value for the Variable. The initial value must have
        a shape specified unless `validate_shape` is set to False. Can also be a
        callable with no argument that returns the initial value when called.
        (Note that initializer functions from init_ops.py must first be bound
         to a shape before being used here.)
      trainable: If `True`, the default, also adds the variable to the graph
        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
        the default list of variables to use by the `Optimizer` classes.
      collections: List of graph collections keys. The new variable is added to
        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
      validate_shape: Ignored. Provided for compatibility with tf.Variable.
      caching_device: Optional device string or function describing where the
        Variable should be cached for reading.  Defaults to the Variable's
        device.  If not `None`, caches on another device.  Typical use is to
        cache on the device where the Ops using the Variable reside, to
        deduplicate copying through `Switch` and other conditional statements.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      dtype: If set, initial_value will be converted to the given type.
        If None, either the datatype will be kept (if initial_value is
        a Tensor) or float32 will be used (if it is a Python object convertible
        to a Tensor).

    Raises:
      ValueError: If the initial value is not specified, or does not have a
        shape and `validate_shape` is `True`.
    """
    if initial_value is None:
      raise ValueError("initial_value must be specified.")
    init_from_fn = callable(initial_value)

    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    if not isinstance(collections, (list, tuple, set)):
      raise ValueError(
          "collections argument to Variable constructor must be a list, tuple, "
          "or set. Got %s of type %s" % (collections, type(collections)))
    if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
      collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
    self._save_slice_info = None
    with ops.control_dependencies(None):
      with ops.name_scope(name, "Variable", [] if init_from_fn else
                          [initial_value]) as name:
        # pylint: disable=protected-access
        true_name = ops._name_from_scope_name(name)
        if init_from_fn:
          # Use attr_scope and device(None) to simulate the behavior of
          # colocate_with when the variable we want to colocate with doesn't
          # yet exist.
          attr = attr_value_pb2.AttrValue(
              list=attr_value_pb2.AttrValue.ListValue(
                  s=[compat.as_bytes("loc:@%s" % true_name)]))
          with ops.get_default_graph()._attr_scope({"_class": attr}):
            with ops.name_scope("Initializer"), ops.device(None):
              self._initial_value = ops.convert_to_tensor(
                  initial_value(), name="initial_value", dtype=dtype)
            self._handle = gen_resource_variable_ops.var_handle_op(
                shape=self._initial_value.get_shape(),
                dtype=self._initial_value.dtype.base_dtype,
                shared_name=true_name, name=name)
        # pylint: enable=protected-access

        # Or get the initial value from a Tensor or Python object.
        else:
          self._initial_value = ops.convert_to_tensor(
              initial_value, name="initial_value", dtype=dtype)
          self._handle = gen_resource_variable_ops.var_handle_op(
              shape=self._initial_value.get_shape(),
              dtype=self._initial_value.dtype.base_dtype,
              shared_name=true_name, name=name)

        self._dtype = self._initial_value.dtype.base_dtype

        with ops.name_scope("IsInitialized"):
          self._is_initialized_op = (
              gen_resource_variable_ops.var_is_initialized_op(self._handle))
        if initial_value is not None:
          with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
            self._initialize_op = gen_resource_variable_ops.assign_variable_op(
                self._handle, self._initial_value, name=n)
        with ops.name_scope("Read"), ops.colocate_with(self._handle):
          # Manually assign reads to the handle's device to avoid log messages.
          with ops.device(self._handle.device):
            value = gen_resource_variable_ops.read_variable_op(
                self._handle, dtype=self._dtype)
          self._graph_element = value
          if caching_device is not None:
            # Variables may be created in a tf.device() or ops.colocate_with()
            # context. At the same time, users would expect caching device to be
            # independent of this context, and/or would not expect the current
            # device context to be merged with the caching device spec.
            # Therefore we reset the colocation stack before creating the cached
            # value. Note that resetting the colocation stack will also reset
            # the device stack.
            with ops.colocate_with(None, ignore_existing=True):
              with ops.device(caching_device):
                self._cached_value = array_ops.identity(value)
          else:
            self._cached_value = None
          ops.add_to_collections(collections, self)
Developer ID: chenjun0210, Project: tensorflow, Lines of code: 101, Source file: resource_variable_ops.py
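The last line of this example, `ops.add_to_collections(collections, self)`, registers the new variable under every collection key listed in `collections`. The sketch below is a minimal pure-Python model of that bookkeeping, not TensorFlow's implementation. `ToyGraph` and its methods are hypothetical stand-ins named after `Graph.add_to_collection`, `add_to_collections`, `get_collection`, and `get_collection_ref`. The sketch assumes, as in the TensorFlow 1.x `Graph` class, that a bare string is treated as a single collection name and that a list of names is de-duplicated before the value is added.

```python
from typing import Any, Dict, Iterable, List, Union


class ToyGraph:
    """Hypothetical stand-in for a graph's named-collection registry."""

    def __init__(self) -> None:
        self._collections: Dict[str, List[Any]] = {}

    def add_to_collection(self, name: str, value: Any) -> None:
        # Append value to one named collection, creating it if needed.
        self._collections.setdefault(name, []).append(value)

    def add_to_collections(self, names: Union[str, Iterable[str]],
                           value: Any) -> None:
        # A bare string is one collection name (not an iterable of chars);
        # any other iterable is de-duplicated so value is added once per key.
        unique = (names,) if isinstance(names, str) else set(names)
        for name in unique:
            self.add_to_collection(name, value)

    def get_collection(self, name: str) -> List[Any]:
        # Return a copy: mutating it does not change the registry.
        return list(self._collections.get(name, []))

    def get_collection_ref(self, name: str) -> List[Any]:
        # Return the live list (creating it if needed) so callers can edit it.
        return self._collections.setdefault(name, [])


g = ToyGraph()
g.add_to_collections(["variables", "trainable_variables", "variables"], "v0")
g.add_to_collections("global_step", "step")
print(g.get_collection("variables"))  # ['v0']
```

The split between the last two methods matters in code like Example 1's `var_creator`, which calls `get_collection_ref` and then `remove()` to edit a collection in place. A copy returned by `get_collection` would leave the registry unchanged.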


Example 16: execute

  def execute(self, fn, *args, **kwargs):
    """Execute function `fn(*args, **kwargs)` inside the CriticalSection.

    Args:
      fn: The function to execute.  Must return at least one tensor.
      *args: Additional positional arguments to `fn`.
      **kwargs: Additional keyword arguments to `fn`.
        Several keywords are reserved for `execute`.  These are:

        - name: The name to use when creating the execute operation.
        - exclusive_resource_access; Whether the resources required by
          `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
          You may want to set this to `False` if you will be accessing a
          resource in read-only mode in two different CriticalSections.

    Returns:
      The tensors returned from `fn(*args, **kwargs)`.

    Raises:
      ValueError: If `fn` attempts to use this `CriticalSection` in any nested
        way.
      ValueError: If `exclusive_resource_access` is not provided (is `True`) and
        another `CriticalSection` has an execution requesting the same
        resources as in `*args`, `**kwargs`, and any additionally captured
        inputs in `fn`.  Note, even if `exclusive_resource_access` is
        `False`, a `ValueError` will be raised if another execution in another
        `CriticalSection` requesting the same resources was created with
        `exclusive_resource_access=True`.
    """
    name = kwargs.pop("name", None)
    exclusive_resource_access = kwargs.pop("exclusive_resource_access", True)

    args = nest.map_structure(ops.convert_to_tensor, args)
    with ops.name_scope(name, "critical_section_execute", []):
      fn_op = function.make_defun_op(fn, *args, **kwargs)
      flat_dtypes = nest.flatten(fn_op.output_dtypes)
      flat_shapes = nest.flatten(fn_op.output_shapes)
      all_inputs = nest.flatten(args) + fn_op.captured_inputs
      if self._handle in all_inputs:
        raise ValueError("The function fn attempts to access the "
                         "CriticalSection in which it would be running.  This "
                         "is illegal and would cause deadlocks.  "
                         "CriticalSection: %s." % self._handle)

      if context.in_graph_mode():
        # Collections and op introspection do not work in eager
        # mode.  This is generally OK, since eager mode (as of
        # writing) executes sequentially anyway.
        all_input_resources = [
            x for x in all_inputs if x.dtype == dtypes.resource]
        for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS):
          if sg.op.inputs[0].name == self._handle.name:
            # Other executions in the same critical section are allowed.
            continue
          if not (exclusive_resource_access or sg.exclusive_resource_access):
            # Neither execution requested exclusive access.
            continue
          sg_input_names = [y.name for y in sg.op.inputs[1:]]
          for res in all_input_resources:
            if res.name in sg_input_names:
              raise ValueError(
                  "This execution would access resource %s; but either this "
                  "execution (CriticalSection: %s) or Execution '%s' "
                  "(CriticalSection: %s) requested exclusive resource access "
                  "of this resource for their critical section.  Did you mean "
                  "to call execute with keyword argument "
                  "exclusive_resource_access=False?"
                  % (res.name,
                     self.name,
                     sg.op.name,
                     sg.op.inputs[0].op.name))

      flat_outputs = gen_resource_variable_ops.execute_in_critical_section(
          critical_section=self._handle,
          arguments=all_inputs,
          f=fn_op,
          output_types=flat_dtypes,
          output_shapes=flat_shapes)

      if context.in_graph_mode():
        if isinstance(flat_outputs, ops.Operation):
          flat_outputs = [flat_outputs]
        op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor)
              else flat_outputs[0])
        signature = _ExecutionSignature(
            op=op,
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXECUTIONS, signature)

      return (flat_outputs[0]
              if (len(flat_outputs) == 1
                  and isinstance(flat_outputs[0], ops.Operation))
              else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs))
Developer ID: ChengYuXiang, Project: tensorflow, Lines of code: 93, Source file: critical_section_ops.py
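The graph-mode branch above records each execution in the `CRITICAL_SECTION_EXECUTIONS` collection and, before adding a new one, scans that collection for conflicts. The following sketch restates that rule in plain Python. It assumes each execution can be summarized by its critical section's name, the names of its resource inputs, and its `exclusive_resource_access` flag. `ExecutionRecord` and `check_conflicts` are hypothetical names, not TensorFlow APIs. A conflict is reported only when the two executions belong to different critical sections, at least one of them requested exclusive access, and they share a resource.

```python
from dataclasses import dataclass
from typing import FrozenSet, List


@dataclass(frozen=True)
class ExecutionRecord:
    """Hypothetical summary of one execution (cf. _ExecutionSignature)."""
    section: str               # name of the critical section's handle
    resources: FrozenSet[str]  # names of the resource inputs it uses
    exclusive: bool            # its exclusive_resource_access flag


def check_conflicts(new: ExecutionRecord,
                    recorded: List[ExecutionRecord]) -> None:
    """Raise ValueError if `new` conflicts with an already recorded execution."""
    for other in recorded:
        if other.section == new.section:
            # Executions in the same critical section are allowed.
            continue
        if not (new.exclusive or other.exclusive):
            # Neither execution requested exclusive access.
            continue
        shared = new.resources & other.resources
        if shared:
            raise ValueError(
                "Resource(s) %s are used by executions in critical sections "
                "%s and %s, and at least one requested exclusive access"
                % (sorted(shared), new.section, other.section))


executions: List[ExecutionRecord] = []  # plays the role of the collection
first = ExecutionRecord("cs_a", frozenset({"var_x"}), exclusive=True)
check_conflicts(first, executions)
executions.append(first)

read_only = ExecutionRecord("cs_b", frozenset({"var_x"}), exclusive=False)
try:
    check_conflicts(read_only, executions)
except ValueError as err:
    print(err)  # cs_a asked for exclusive access, so sharing var_x fails
```

Keeping the check symmetric (either flag triggers it) means the order in which two critical sections are built does not change whether the conflict is caught.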


Example 17: execute


#......... part of the code is omitted here .........
    calling `fn` in the critical section, create a lambda:

    ```python
    critical_section.execute(lambda: fn(*my_args, **my_kwargs))
    ```

    Args:
      fn: The function to execute.  Must return at least one tensor.
      exclusive_resource_access: Whether the resources required by
        `fn` should be exclusive to this `CriticalSection`.  Default: `True`.
        You may want to set this to `False` if you will be accessing a
        resource in read-only mode in two different CriticalSections.
      name: The name to use when creating the execute operation.

    Returns:
      The tensors returned from `fn()`.

    Raises:
      ValueError: If `fn` attempts to lock this `CriticalSection` in any nested
        or lazy way that may cause a deadlock.
      ValueError: If `exclusive_resource_access == True` and
        another `CriticalSection` has an execution requesting the same
        resources as `fn`.  Note, even if `exclusive_resource_access` is
        `False`, a `ValueError` will be raised if another execution in another
        `CriticalSection` requesting the same resources was created with
        `exclusive_resource_access=True`.
    """
    with ops.name_scope(name, "critical_section_execute", []):

      # Ensure that mutex locking only happens *after* all args and
      # kwargs have been executed.  This avoids certain types of deadlocks.
      lock = gen_resource_variable_ops.mutex_lock(self._handle)

      if not context.executing_eagerly():
        # NOTE(ebrevdo): This is to ensure we don't pick up spurious
        # Operations created by other threads.
        with ops.get_default_graph()._lock:  # pylint: disable=protected-access
          existing_ops = ops.get_default_graph().get_operations()
          with ops.control_dependencies([lock]):
            r = fn()
          # TODO(ebrevdo): If creating critical sections in a python loop, this
          # makes graph creation time quadratic.  Revisit if this
          # becomes a problem.
          created_ops = (set(ops.get_default_graph().get_operations())
                         .difference(existing_ops))
      else:
        with ops.control_dependencies([lock]):
          r = fn()

      if not context.executing_eagerly():
        self._add_control_dependencies_to_lock(created_ops, lock.op)

        # captured_resources is a set of resources that are directly
        # accessed only by ops created during fn(), not by any
        # ancestors of those ops in the graph.
        captured_resources = set([
            input_ for op in created_ops
            for input_ in op.inputs
            if input_.dtype == dtypes.resource
        ])

        # NOTE(ebrevdo): The only time self._is_self_handle() is True
        # in this call is if one of the recently created ops, within
        # the execute(), itself attempts to access the
        # CriticalSection.  This will cause a deadlock.
        if any(self._is_self_handle(x) for x in captured_resources):
          raise ValueError("The function fn attempts to directly access the "
                           "CriticalSection in which it would be running.  "
                           "This is illegal and would cause deadlocks.")

        self._check_multiple_access_to_resources(
            captured_resources, exclusive_resource_access)

      r_flat = [_identity(x) for x in nest.flatten(r)]

      with ops.control_dependencies(r_flat):
        # The identity must run on the same machine as self._handle
        with ops.colocate_with(self._handle):
          # Do not use array_ops.identity as there are special
          # optimizations within TensorFlow which seem to elide it
          # even when optimizations are disabled(!).
          ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock(
              lock)

        # Make sure that if any element of r is accessed, all of
        # them are executed together.
        r = nest.pack_sequence_as(r, control_flow_ops.tuple(nest.flatten(r)))

      with ops.control_dependencies([ensure_lock_exists]):
        outputs = nest.map_structure(_identity, r)

      if not context.executing_eagerly():
        signature = _ExecutionSignature(
            op=lock.op,
            handle=self._handle,
            resources=list(captured_resources),
            exclusive_resource_access=exclusive_resource_access)
        ops.add_to_collections(
            CRITICAL_SECTION_EXEC 
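In graph mode, this version of `execute` finds the ops that `fn` created by snapshotting the graph's operations while holding the graph lock, calling `fn`, and taking the set difference. The resource-typed inputs of those new ops are the `captured_resources` it then checks. The sketch below models that sequence with hypothetical `Tensor`, `Op`, and `ToyGraph` classes and a hypothetical `trace_and_check` helper. None of these are TensorFlow APIs. The re-entrant lock is an assumption standing in for the graph's internal lock, and the mutex op and control dependencies are omitted.

```python
import threading
from dataclasses import dataclass, field
from typing import Any, Callable, List, Set, Tuple


@dataclass(frozen=True)
class Tensor:
    """Hypothetical tensor: a name plus a flag marking resource handles."""
    name: str
    is_resource: bool = False


@dataclass(frozen=True, eq=False)
class Op:
    """Hypothetical op; eq=False gives identity-based hashing, like graph ops."""
    name: str
    inputs: Tuple[Tensor, ...] = ()


@dataclass
class ToyGraph:
    """Hypothetical graph: an op list guarded by a re-entrant lock."""
    ops: List[Op] = field(default_factory=list)
    lock: Any = field(default_factory=threading.RLock)

    def create_op(self, name: str, *inputs: Tensor) -> Op:
        # Every op creation takes the lock, so other threads block while
        # trace_and_check holds it.
        with self.lock:
            op = Op(name, tuple(inputs))
            self.ops.append(op)
            return op


def trace_and_check(graph: ToyGraph, fn: Callable[[], Any],
                    self_handle: Tensor) -> Set[Tensor]:
    """Run fn and return the resources read directly by the ops it created."""
    with graph.lock:  # re-entrant, so fn may still call graph.create_op
        existing = set(graph.ops)
        fn()
        created = set(graph.ops).difference(existing)
    # Only direct inputs of the new ops count, not their ancestors.
    captured = {t for op in created for t in op.inputs if t.is_resource}
    if self_handle in captured:
        raise ValueError("fn directly accesses the critical section it runs "
                         "in; this would deadlock")
    return captured


graph = ToyGraph()
handle = Tensor("critical_section", is_resource=True)
var_x = Tensor("var_x", is_resource=True)
graph.create_op("created_before_fn", var_x)  # not attributed to fn
print(trace_and_check(graph, lambda: graph.create_op("read_x", var_x), handle))
```

Holding the lock across the snapshot, the call to `fn`, and the diff is what keeps ops created concurrently by other threads out of `created`, as the `NOTE(ebrevdo)` comment explains. Rebuilding the full operation set on every call is also the source of the quadratic cost mentioned in the `TODO`.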

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap