Python tape.record_operation Function Code Examples


This article collects typical usage examples of the Python function tensorflow.python.eager.tape.record_operation. If you have been wondering what record_operation does, how to call it, or where to find real-world uses of it, the curated examples below should help.

The 16 code examples that follow are ordered by popularity. A short orientation sketch of the shared call pattern precedes Example 1.
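
Before the examples, here is a quick orientation sketch of the call pattern they share. record_operation is an internal, unstable TensorFlow API: both its module location and its signature have drifted across releases (the examples below variously pass three, four, or five arguments). The sketch assumes a TF 2.x build where tensorflow.python.eager.tape still exposes record_operation and stop_recording; the "FakeCube" op name and its 3*x**2 gradient are made up purely to show that the tape replays the supplied backward function.

import tensorflow as tf
from tensorflow.python.eager import tape  # internal module used by every example below

x = tf.constant(2.0)
with tf.GradientTape() as g:
  g.watch(x)
  with tape.stop_recording():  # keep the forward op itself off the tape
    y = tf.square(x)
  # Record one entry by hand: outputs [y], inputs [x], and a backward
  # function mapping output gradients to input gradients.
  tape.record_operation("FakeCube", [y], [x],
                        lambda dy: [dy * 3.0 * tf.square(x)])

print(g.gradient(y, x))  # 12.0: the tape used the recorded backward function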

Example 1: capture_value

def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg."""
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes_module.resource:
      handle_data = value._handle_data  # pylint: disable=protected-access
      captured_value._handle_data = handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # Ensure that shapes and dtypes are propagated.
        shapes, types = zip(*[(pair.shape, pair.dtype)
                              for pair in handle_data.shape_and_type])
        ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
        shapes = [[d.size for d in s.dim]
                  if not s.unknown_rank else None for s in shapes]
        with errors.raise_exception_on_not_ok_status() as status:
          pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
              captured_value._op._graph._c_graph,  # pylint: disable=protected-access
              captured_value._as_tf_output(),  # pylint: disable=protected-access
              shapes,
              ranks,
              types,
              status)

    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value
Author: AndrewTwinz | Project: tensorflow | Lines: 31 | Source: function.py


Example 2: _convert_to_graph_tensor

def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
  """Captures a Tensor while building a graph mode function.

  Arguments:
    value: A Tensor object.
    dtype: The datatype of the value produced by the node in the graph.
    name:  Name of the node in the graph.
    as_ref: Ignored (required by register_tensor_conversion_function).

  Returns:
    Returns a constant (the current value of the tensor) if capturing
    is not enabled. A placeholder which will have the value of the
    tensor at runtime otherwise.
  """
  if context.in_eager_mode():
    return value
  _ = as_ref
  tensor_map = _scoped_captures.tensors
  if tensor_map is None:
    # Capturing is not enabled.
    return constant_op.constant(value.numpy())
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes.resource:
      captured_value._handle_data = value._handle_data  # pylint: disable=protected-access
    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value], [],
                        lambda x: x)
  return captured_value
Author: Mazecreator | Project: tensorflow | Lines: 33 | Source: function.py


Example 3: _capture_helper

  def _capture_helper(self, tensor, name):
    captured_tensor = self.captures.get(tensor, None)
    if captured_tensor is None:
      captured_tensor = _create_substitute_placeholder(tensor, name=name,
                                                       dtype=tensor.dtype)
      self.captures[tensor] = captured_tensor
      self.inputs.append(captured_tensor)
    tape.record_operation("captured_value", [captured_tensor], [tensor],
                          lambda x: [x])
    return captured_tensor
Author: rmlarsen | Project: tensorflow | Lines: 10 | Source: func_graph.py
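
For context, this capture path is what runs whenever a tf.function body references a tensor created outside of it: the external tensor is replaced by a placeholder in the FuncGraph and recorded on the tape with an identity gradient, exactly as above. A minimal public-API sketch (the names here are illustrative):

import tensorflow as tf

outer = tf.constant(3.0)  # eager tensor created outside the traced function

@tf.function
def scale(x):
  # Referencing `outer` triggers _capture_helper during tracing: inside
  # the FuncGraph it becomes a placeholder fed with `outer` at call time.
  return x * outer

print(scale(tf.constant(2.0)))  # tf.Tensor(6.0, shape=(), dtype=float32)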


Example 4: _record_gradient

def _record_gradient(op_name, inputs, attrs, results, ctx, name):
  """Records gradients for a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    inputs: A flat list of Tensor object inputs to the operation.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    results: The results of the operation (as a flat list).
    ctx: The value of context.context().
    name: Customized name for the operation.

  Returns:
    A list of maybe-wrapped results. Either Tensors or TensorNodes.

  Raises:
    An exception on error.
  """
  if not tape.could_possibly_record():
    return

  if op_name in _ops_which_dont_need_outputs:
    op_outputs = None
  else:
    # TODO(apassos) this line creates a weak circular reference where the
    # backprop function keeps an output alive which in turn keeps the tape entry
    # alive which keeps the backprop function alive. Figure out how to break
    # this up without breaking second derivatives of ops like Exp whose
    # gradients depend only on the outputs.
    op_outputs = results

  if op_name in _ops_which_dont_need_inputs:
    op_inputs = None
  else:
    op_inputs = inputs

  num_inputs = len(inputs)

  def grad_fn(*orig_outputs):
    """Generated gradient function."""
    result = _magic_gradient_function(op_name, attrs, num_inputs,
                                      op_inputs, op_outputs, orig_outputs)
    if _tracing:
      print("Gradient for", (name if name else op_name), "inputs", op_inputs,
            "output_grads", orig_outputs, "gradients", result)
    return result

  inputs = [ops.internal_convert_to_tensor(x, ctx=ctx) for x in inputs]
  tape.record_operation(op_name, results, inputs, [], grad_fn)
  if _tracing:
    print("Computed op", (name if name else op_name), "inputs", inputs,
          "outputs", results)
Author: Crazyonxh | Project: tensorflow | Lines: 53 | Source: backprop.py
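
This hook is what makes ordinary eager ops differentiable: every eager kernel execution reports its inputs and outputs here, so an enclosing tape can later replay the registered gradient function. No internal APIs are needed to observe the effect; the Exp case mentioned in the TODO above works like this:

import tensorflow as tf

x = tf.constant(1.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = tf.exp(x)  # executing the op records it via _record_gradient

# The tape replays the registered Exp gradient, which needs only the output.
print(g.gradient(y, x))  # 2.7182817 (= e**1)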


Example 5: capture_value

def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg."""
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes.resource:
      captured_value._handle_data = value._handle_data  # pylint: disable=protected-access
    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value
Author: SylChan | Project: tensorflow | Lines: 14 | Source: function.py


Example 6: decorated

  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    if context.in_graph_mode():
      if kwargs:
        raise ValueError(
            "custom_gradient in graph mode doesn't support keyword arguments.")
      name = "CustomGradient-%s" % tf_ops.uid()
      args = [tf_ops.convert_to_tensor(x) for x in args]
      result, grad_fn = f(*args)
      flat_result = nest.flatten(result)
      all_tensors = flat_result + args

      @tf_ops.RegisterGradient(name)
      def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        return ([None] * len(flat_result)) + gradients

      with tf_ops.get_default_graph().gradient_override_map(
          {"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)
      return nest.pack_sequence_as(
          structure=result, flat_sequence=all_tensors[:len(flat_result)])

    input_tensors = [x for x in args
                     if isinstance(x, tf_ops.Tensor)]

    with tape.stop_recording():
      result, grad_fn = f(*args, **kwargs)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    tape.record_operation(
        f.__name__,
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    flat_result = list(flat_result)
    return result
Author: Mazecreator | Project: tensorflow | Lines: 47 | Source: custom_gradient.py


Example 7: decorated

  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    if context.in_graph_mode():
      if kwargs:
        raise ValueError(
            "custom_gradient in graph mode doesn't support keyword arguments.")
      name = "CustomGradient-%s" % tf_ops.uid()
      args = [tf_ops.convert_to_tensor(x) for x in args]
      result, grad_fn = f(*args)
      flat_result = nest.flatten(result)
      all_tensors = flat_result + args

      @tf_ops.RegisterGradient(name)
      def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
        gradients = nest.flatten(grad_fn(*result_grads[:len(flat_result)]))
        # Need to return one value per input to the IdentityN, so pad the
        # gradients of the inputs of the custom_gradient function with the
        # gradients of the outputs as well.
        return ([None] * len(flat_result)) + gradients

      with tf_ops.get_default_graph().gradient_override_map(
          {"IdentityN": name}):
        all_tensors = array_ops.identity_n(all_tensors)
      return nest.pack_sequence_as(
          structure=result, flat_sequence=all_tensors[:len(flat_result)])

    input_tensors = [tf_ops.convert_to_tensor(x) for x in args]

    result, grad_fn = f(*args, **kwargs)
    flat_result = nest.flatten(result)
    # TODO(apassos) consider removing the identity below.
    flat_result = [gen_array_ops.identity(x) for x in flat_result]

    def actual_grad_fn(*outputs):
      return nest.flatten(grad_fn(*outputs))

    tape.record_operation(
        f.__name__,
        flat_result,
        input_tensors,
        actual_grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(result, flat_result)
Author: neuroradiology | Project: tensorflow | Lines: 43 | Source: custom_gradient.py


Example 8: _backprop_call

  def _backprop_call(self, args):
    """Calls the wrapped function and records the result on a tape."""
    all_args = args + self._extra_inputs
    signature = self._forward_fdef.definition.signature
    ctx = context.context()
    if ctx.in_graph_mode():
      g = ops.get_default_graph()
      g._add_function(self._forward_fdef)  # pylint: disable=protected-access
      def make_tensor(x):
        if isinstance(x, ops.Tensor):
          return x
        return ops.internal_convert_to_tensor(x, ctx=ctx)
      op = g.create_op(
          signature.name, [make_tensor(x) for x in all_args],
          [dtypes.DType(x.type) for x in signature.output_arg],
          op_def=signature,
          name="FunctionCall",
          compute_shapes=False)
      outputs = op.outputs
      outputs = [outputs] if isinstance(
          outputs, (ops.Tensor, type(None))) else list(outputs)
      for i, s in enumerate(self._output_shapes):
        outputs[i].set_shape(s)
    else:
      outputs = execute.execute(
          str(signature.name),
          num_outputs=len(signature.output_arg),
          inputs=all_args,
          attrs=None,
          ctx=ctx)
    real_outputs = outputs[:len(self._returns)]
    side_outputs = outputs[len(self._returns):]

    def backward_function(*args):
      return self._backward_function(*(list(args) + side_outputs))

    tape.record_operation(
        signature.name,
        real_outputs,
        (args + self._extra_inputs),
        backward_function)

    return self._build_call_outputs(real_outputs)
Author: SylChan | Project: tensorflow | Lines: 43 | Source: function.py


Example 9: _eager_mode_decorator

def _eager_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for eager mode."""
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args, **kwargs)
  all_inputs = list(args) + list(kwargs.values())
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = [v for v in set(tape.watched_variables()) if v not in all_inputs]
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  if (variables and ("variables" not in grad_argspec.args) and
      not grad_argspec.varkw):
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  flat_result = nest.flatten(result)
  # TODO(apassos) consider removing the identity below.
  flat_result = [gen_array_ops.identity(x) for x in flat_result]

  input_tensors = [ops.convert_to_tensor(x) for x
                   in list(args) + list(variables)]
  arg_count = len(args)
  def actual_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []
    flat_grads = nest.flatten(input_grads)
    if len(flat_grads) != arg_count:
      raise ValueError(
          "custom_gradient function expected to return", arg_count,
          "gradients but returned", len(flat_grads), "instead.")
    return nest.flatten(input_grads) + variable_grads

  tape_lib.record_operation(f.__name__, flat_result, input_tensors,
                            actual_grad_fn)
  flat_result = list(flat_result)
  return nest.pack_sequence_as(result, flat_result)
Author: adit-chandra | Project: tensorflow | Lines: 42 | Source: custom_gradient.py
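
The decorator implemented above backs the public tf.custom_gradient API. For completeness, here is the canonical usage example from the TensorFlow documentation, which exercises this eager code path:

import tensorflow as tf

@tf.custom_gradient
def log1pexp(x):
  e = tf.exp(x)
  def grad(dy):  # becomes grad_fn, wrapped by actual_grad_fn above
    return dy * (1 - 1 / (1 + e))
  return tf.math.log(1 + e), grad

x = tf.constant(100.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = log1pexp(x)

print(g.gradient(y, x))  # 1.0: numerically stable where the naive gradient is NaN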


Example 10: capture_value

def capture_value(tensor_map, value, dtype, name):
  """Capture a value from outside the function, to pass in as an extra arg."""
  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes_module.resource:
      if ops._USE_C_SHAPES:  # pylint: disable=protected-access
        if isinstance(value, ops.EagerTensor):
          handle_data = value._handle_data  # pylint: disable=protected-access
        else:
          handle_data = resource_variable_ops.get_resource_handle_data(value)
      else:
        handle_data = value._handle_data  # pylint: disable=protected-access
      if handle_data is not None and handle_data.is_set:
        # pylint: disable=protected-access
        if ops._USE_C_SHAPES:
          pywrap_tensorflow.SetResourceHandleShapeAndType(
              captured_value.graph._c_graph, captured_value._as_tf_output(),
              handle_data.SerializeToString())
        else:
          captured_value._handle_data = handle_data
        # pylint: enable=protected-access
        # Ensure that shapes and dtypes are propagated.
        shapes, types = zip(*[(pair.shape, pair.dtype)
                              for pair in handle_data.shape_and_type])
        ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes]
        shapes = [[d.size for d in s.dim]
                  if not s.unknown_rank else None for s in shapes]
        pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper(
            captured_value._op._graph._c_graph,  # pylint: disable=protected-access
            captured_value._as_tf_output(),  # pylint: disable=protected-access
            shapes, ranks, types)

    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  tape.record_operation("captured_value", [captured_value], [value],
                        lambda x: [x])
  return captured_value
Author: Jackiefan | Project: tensorflow | Lines: 40 | Source: function.py


Example 11: decorated

  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    input_tensors = [x for x in args
                     if isinstance(x, tf_ops.Tensor)]

    with tape.stop_recording():
      result, grad_fn = f(*args, **kwargs)

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    tape.record_operation(
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    flat_result = list(flat_result)
    return result
Author: 1000sprites | Project: tensorflow | Lines: 22 | Source: custom_gradient.py


Example 12: decorated

  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    input_tensors = [_watch_value_from_tape(x) for x in args
                     if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
                     or ag_core.isnode(x)]
    result, grad_fn = f(*args, **kwargs)

    flat_result = nest.flatten(result)
    flat_result = [ag_core.getval(x) for x in flat_result]
    flat_result = tape.record_operation(
        flat_result,
        input_tensors,
        [],
        grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
Author: Dr4KK | Project: tensorflow | Lines: 16 | Source: custom_gradient.py


Example 13: _record_gradient

def _record_gradient(op_name, inputs, attrs, results, name):
  """Records gradients for a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    inputs: A flat list of Tensor object inputs to the operation.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    results: The results of the operation (as a flat list).
    name: Customized name for the operation.

  Returns:
    A list of maybe-wrapped results. Either Tensors or TensorNodes.

  Raises:
    An exception on error.
  """
  if not any(ag_core.isnode(x) for x in inputs):
    return results
  num_outputs = len(results)
  if num_outputs == 0:
    return results
  if attrs is not None:
    attrs = tuple(tuple(x) if isinstance(x, list) else x for x in attrs)

  # It is imperative we make a copy of results here as otherwise we create a
  # dependency cycle in the captured function and this can delay garbage
  # collecting of the tensors arbitrarily.
  results_size = len(results) if isinstance(results, (list, tuple)) else 1

  def grad_fn(*orig_outputs):
    """Generated gradient function."""
    tensors = inputs + list(orig_outputs)
    tensors = container_types.make_sequence(tape.EagerList, *tensors)
    result = _magic_gradient_function(op_name, attrs, len(inputs),
                                      num_outputs, *(tensors))
    if _tracing:
      print("Gradient for", (name if name else op_name), "inputs", inputs,
            "output_grads", orig_outputs[results_size:], "gradients", result)
    return result

  results = tape.record_operation(results, inputs, [], grad_fn)
  if _tracing:
    print("Computed op", (name if name else op_name), "inputs", inputs,
          "outputs", results)
  return results
Author: keveman | Project: tensorflow | Lines: 47 | Source: backprop.py


Example 14: _backprop_call

  def _backprop_call(self, args):
    """Calls the wrapped function and records the result on a tape."""
    all_args = args + self._extra_inputs
    signature = self._forward_fdef.definition.signature
    if context.in_graph_mode():
      g = ops.get_default_graph()
      g._add_function(self._forward_fdef)  # pylint: disable=protected-access
      unwrapped_args = [ag_core.getval(x) for x in all_args]
      op = g.create_op(
          signature.name, [ops.convert_to_tensor(x) for x in unwrapped_args],
          [dtypes.DType(x.type) for x in signature.output_arg],
          op_def=signature,
          name="FunctionCall",
          compute_shapes=False)
      outputs = op.outputs
      outputs = [outputs] if isinstance(
          outputs, (tensor.Tensor, ops.Tensor, type(None))) else list(outputs)
      for i, s in enumerate(self._output_shapes):
        outputs[i].set_shape(s)
    else:
      outputs = execute.execute(
          signature.name,
          num_outputs=len(signature.output_arg),
          inputs=all_args)
    real_outputs = outputs[:len(self._returns)]
    side_outputs = outputs[len(self._returns):]
    watched_extra_inputs = []
    for t in self._extra_inputs:
      tid = ops.tensor_id(t)
      for t in tape._tape_stack.stack:  # pylint: disable=protected-access
        w = t.value.tensors.get(tid, None)
        if w is not None:
          watched_extra_inputs.append(w)
          break
      else:  # Note: for-else here done on purpose
        watched_extra_inputs.append(t)

    def backward_function_wrapper(*outputs):
      outputs = outputs[len(real_outputs):]
      return self._backward_function(*outputs)
    real_outputs = tape.record_operation(
        real_outputs,
        (args + watched_extra_inputs),
        side_outputs,
        backward_function_wrapper)

    return self._build_call_outputs(self._returns, real_outputs)
Author: keveman | Project: tensorflow | Lines: 47 | Source: function.py


Example 15: decorated

  def decorated(*args, **kwargs):
    """Decorated function with custom gradient."""
    input_tensors = [_watch_value_from_tape(x) for x in args
                     if isinstance(x, (_tensor.Tensor, tf_ops.Tensor))
                     or ag_core.isnode(x)]
    result, grad_fn = f(*args, **kwargs)
    result_size = len(result) if isinstance(result, (list, tuple)) else 1

    # TODO(apassos): naive uses of custom_gradient will not get the correct
    # second derivative this way if they capture any output tensors. Change the
    # signature of custom_gradient.
    def actual_grad_fn(*outputs):
      outputs = outputs[result_size:]
      return grad_fn(*outputs)

    flat_result = nest.flatten(result)
    flat_result = [ag_core.getval(x) for x in flat_result]
    flat_result = tape.record_operation(
        flat_result,
        input_tensors,
        [],
        actual_grad_fn)
    flat_result = list(flat_result)
    return nest.pack_sequence_as(structure=result, flat_sequence=flat_result)
Author: solaris33 | Project: tensorflow | Lines: 24 | Source: custom_gradient.py


Example 16: _graph_mode_decorator

def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keywords "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  new_vars = after_vars - before_vars
  for v in new_vars:
    if not isinstance(v, resource_variable_ops.ResourceVariable):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = list(set(tape.watched_variables()) - set(args))
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                   "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:len(flat_result)]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * len(flat_result)) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)
  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])
Author: terrytangyuan | Project: tensorflow | Lines: 84 | Source: custom_gradient.py



Note: the tensorflow.python.eager.tape.record_operation examples in this article were compiled by 纯净天空 (VimSky) from GitHub, MSDocs, and other source-code and documentation platforms. The snippets are selected from open-source projects contributed by their authors; copyright remains with the original authors, and any distribution or use should follow the corresponding project's license. Please do not reproduce without permission.

