• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python types.as_dtype函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.framework.types.as_dtype函数的典型用法代码示例。如果您正苦于以下问题:Python as_dtype函数的具体用法?Python as_dtype怎么用?Python as_dtype使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了as_dtype函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: _DefaultGradYs

def _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops):
  """Replace every missing entry of `grad_ys` with a tensor of ones.

  Each `None` gradient is replaced by a tensor filled with the constant 1,
  built from the shape and dtype of the matching `y`. Entries that are
  already present are only checked for a dtype that equals the dtype of
  the matching `y`.

  Args:
    grad_ys: List of gradients, can contain None.
    ys: List of tensors, one per entry of `grad_ys`.
    colocate_gradients_with_ops: Forwarded to `_GetGradsDevice` to pick the
      device under which the default gradient is created.

  Returns:
    A list of gradients to use, without None.

  Raises:
    ValueError: If the two lists differ in length, or if a supplied
      gradient has a different dtype than its `y`.
  """
  num_grads, num_ys = len(grad_ys), len(ys)
  if num_grads != num_ys:
    raise ValueError("Passed %d grad_ys for %d ys" % (num_grads, num_ys))
  grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
  for index, (grad_y, y) in enumerate(zip(grad_ys, ys)):
    if grad_y is None:
      # No gradient supplied: default to ones shaped and typed like `y`.
      with ops.device(_GetGradsDevice(y.op, colocate_gradients_with_ops)):
        one = constant_op.constant(1, dtype=y.dtype)
        grad_ys[index] = array_ops.fill(array_ops.shape(y), one)
      continue
    if grad_y.dtype != y.dtype:
      raise ValueError("Y and ys_grad must be of the same type, "
                       "not y: %s, ys_grad: %s " %
                       (types.as_dtype(y.dtype).name,
                        types.as_dtype(grad_y.dtype).name))
  return grad_ys
开发者ID:bradg19,项目名称:tensor,代码行数:32,代码来源:gradients.py


示例2: _SatisfiesTypeConstraint

def _SatisfiesTypeConstraint(dtype, attr_def):
  """Check `dtype` against the allowed values declared on `attr_def`.

  Nothing is checked when `attr_def` has no "allowed_values" field.

  Args:
    dtype: The data type value to look up in the allowed list.
    attr_def: Attr definition providing `HasField`, `allowed_values` and
      `name`.

  Raises:
    TypeError: If `attr_def` restricts its values and `dtype` is not among
      them.
  """
  if not attr_def.HasField("allowed_values"):
    return
  permitted = attr_def.allowed_values.list.type
  if dtype in permitted:
    return
  # Build the message pieces in the order they appear in the message.
  dtype_name = types_lib.as_dtype(dtype).name
  attr_name = attr_def.name
  permitted_names = ", ".join(types_lib.as_dtype(x).name for x in permitted)
  raise TypeError(
      "DataType %s for attr '%s' not in list of allowed values: %s" %
      (dtype_name, attr_name, permitted_names))
开发者ID:adeelzaman,项目名称:tensorflow,代码行数:8,代码来源:op_def_library.py


示例3: testAllTypesConvertibleToNumpyDtype

 def testAllTypesConvertibleToNumpyDtype(self):
   """Every valid DataType enum maps to a usable numpy dtype and back."""
   for enum_value in types_pb2.DataType.values():
     if enum_value != types_pb2.DT_INVALID:
       tf_dtype = types.as_dtype(enum_value)
       np_dtype = tf_dtype.as_numpy_dtype
       # Allocating an array is the check that numpy accepts the dtype.
       np.empty((1, 1, 1, 1), dtype=np_dtype)
       if tf_dtype.base_dtype != types.bfloat16:
         # NOTE(touts): Intentionally no way to feed a DT_BFLOAT16.
         round_tripped = types.as_dtype(enum_value).base_dtype
         self.assertEqual(round_tripped, types.as_dtype(np_dtype))
开发者ID:debaratidas1994,项目名称:tensorflow,代码行数:11,代码来源:types_test.py


示例4: testIsInteger

 def testIsInteger(self):
   """`is_integer` is True only for the integer dtype names listed here."""
   expectations = (
       ("int8", True),
       ("int16", True),
       ("int32", True),
       ("int64", True),
       ("uint8", True),
       ("complex64", False),
       ("float", False),
       ("double", False),
       ("string", False),
       ("bool", False),
   )
   for type_name, expected in expectations:
     self.assertEqual(types.as_dtype(type_name).is_integer, expected)
开发者ID:debaratidas1994,项目名称:tensorflow,代码行数:11,代码来源:types_test.py


示例5: testIsFloating

 def testIsFloating(self):
   """`is_floating` is True only for the float dtype names listed here."""
   expectations = (
       ("int8", False),
       ("int16", False),
       ("int32", False),
       ("int64", False),
       ("uint8", False),
       ("complex64", False),
       ("float32", True),
       ("float64", True),
       ("string", False),
       ("bool", False),
   )
   for type_name, expected in expectations:
     self.assertEqual(types.as_dtype(type_name).is_floating, expected)
开发者ID:hal2001,项目名称:tensorflow,代码行数:11,代码来源:types_test.py


示例6: ones

def ones(shape, dtype=types.float32, name=None):
    """Creates a tensor with all elements set to 1.

    This operation returns a tensor of type `dtype` with shape `shape` and
    all elements set to 1.

    For example:

    ```python
    tf.ones([2, 3], int32) ==> [[1, 1, 1], [1, 1, 1]]
    ```

    Args:
      shape: Either a list of integers, or a 1-D `Tensor` of type `int32`.
      dtype: The type of an element in the resulting `Tensor`.
      name: A name for the operation (optional).

    Returns:
      A `Tensor` with all elements set to 1.
    """
    with ops.op_scope([shape], name, "ones") as name:
        if not isinstance(shape, list):
            # Anything that is not a Python list is converted to a tensor
            # and the result is produced by a fill op.
            shape = ops.convert_to_tensor(shape, name="shape")
            one = constant(1, dtype=dtype)
            result = fill(shape, one, name=name)
        else:
            # A Python list shape is handed straight to the constant op.
            result = constant(1, shape=shape, dtype=dtype, name=name)
    assert result.dtype.base_dtype == types.as_dtype(dtype).base_dtype
    return result
开发者ID:nguyenductung,项目名称:tensorflow,代码行数:28,代码来源:array_ops.py


示例7: __init__

    def __init__(self, key_dtype, value_dtype, default_value, table_ref):
        """Build a table wrapper around an existing table reference.

        Args:
          key_dtype: The table key type; normalized with `types.as_dtype`.
          value_dtype: The table value type; normalized with
            `types.as_dtype`.
          default_value: The value to use if a key is missing in the table.
            It is converted to a tensor of the value dtype.
          table_ref: The table reference, i.e. the output of the lookup
            table ops.
        """
        self._key_dtype = types.as_dtype(key_dtype)
        self._value_dtype = types.as_dtype(value_dtype)
        self._shapes = [tensor_shape.TensorShape([1])]
        self._table_ref = table_ref
        # The table is named after the last "/"-separated component of the
        # producing op's name.
        op_name = self._table_ref.op.name
        self._name = op_name.rsplit("/", 1)[-1]
        default_tensor = ops.convert_to_tensor(default_value, dtype=self._value_dtype)
        self._default_value = default_tensor
        # Merge the default value's shape with a scalar shape; the merged
        # result itself is not kept.
        default_tensor.get_shape().merge_with(tensor_shape.scalar())
开发者ID:swapnilashtekar,项目名称:tensorflow,代码行数:16,代码来源:data_flow_ops.py


示例8: _MakeType

def _MakeType(v, attr_def):
  """Convert `v` to a DataType enum value valid for `attr_def`.

  Args:
    v: Value to convert with `types_lib.as_dtype`.
    attr_def: Attr definition used for the error message and passed to
      `_SatisfiesTypeConstraint`.

  Returns:
    The `as_datatype_enum` value of the converted dtype.

  Raises:
    TypeError: If `v` cannot be converted to a dtype, or if
      `_SatisfiesTypeConstraint` rejects the resulting enum.
  """
  try:
    converted = types_lib.as_dtype(v)
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (attr_def.name, repr(v)))
  type_enum = converted.as_datatype_enum
  _SatisfiesTypeConstraint(type_enum, attr_def)
  return type_enum
开发者ID:adeelzaman,项目名称:tensorflow,代码行数:9,代码来源:op_def_library.py


示例9: testDTypesHaveUniqueNames

 def testDTypesHaveUniqueNames(self):
   """No two valid DataType enums convert to dtypes sharing a name."""
   valid_dtypes = [
       types.as_dtype(enum_value)
       for enum_value in types_pb2.DataType.values()
       if enum_value != types_pb2.DT_INVALID
   ]
   unique_names = set(d.name for d in valid_dtypes)
   # As many distinct names as dtypes means no name is shared.
   self.assertEqual(len(valid_dtypes), len(unique_names))
开发者ID:debaratidas1994,项目名称:tensorflow,代码行数:10,代码来源:types_test.py


示例10: _VerifyGeneratedGradients

def _VerifyGeneratedGradients(grads, op):
  """Check that `grads` matches the inputs of `op` in count and dtype.

  `None` gradients are accepted without a dtype check.

  Args:
    grads: List of generated gradients.
    op: Operation for which the gradients were generated.

  Raises:
    ValueError: if the number of gradients differs from the number of
      inputs, or a gradient dtype is not compatible with its input dtype.
  """
  num_grads = len(grads)
  if num_grads != len(op.inputs):
    raise ValueError("Num gradients %d generated for op %s do not match num "
                     "inputs %d" % (num_grads, op.node_def, len(op.inputs)))
  for grad, inp in zip(grads, op.inputs):
    if grad is None:
      continue
    if grad.dtype.is_compatible_with(inp.dtype):
      continue
    raise ValueError(
        "Gradient type %s generated for op %s does "
        "not match input type %s" %
        (types.as_dtype(grad.dtype).name, op.node_def,
         types.as_dtype(inp.dtype).name))
开发者ID:bradg19,项目名称:tensor,代码行数:23,代码来源:gradients.py


示例11: _ComputeGradient

def _ComputeGradient(x, x_shape, dx, y, y_shape, dy,
                     x_init_value=None, delta=1e-3):
  """Computes the theoretical and numerical jacobian.

  Args:
    x: Input tensor; its base dtype must be float32 or float64.
    x_shape: Shape of `x`; used to generate or validate the input data.
    dx: Forwarded to `_ComputeTheoricalJacobian`.
    y: Output tensor; its base dtype must be float32 or float64.
    y_shape: Shape of `y`, forwarded to both jacobian helpers.
    dy: Forwarded to `_ComputeTheoricalJacobian`.
    x_init_value: Optional array used as the value of `x`. Its shape must
      equal `x_shape`. When None, random sample data is generated.
    delta: Forwarded to `_ComputeNumericJacobian`.

  Returns:
    A `(jacob_t, jacob_n)` tuple: the results of
    `_ComputeTheoricalJacobian` and `_ComputeNumericJacobian`.

  Raises:
    AssertionError: If `x` or `y` has an unsupported dtype, or if
      `x_init_value` does not have shape `x_shape`.
  """
  t = types.as_dtype(x.dtype)
  allowed_types = [types.float32, types.float64]
  assert t.base_dtype in allowed_types, "Don't support type %s for x" % t.name
  t2 = types.as_dtype(y.dtype)
  assert t2.base_dtype in allowed_types, "Don't support type %s for y" % t2.name

  if x_init_value is not None:
    i_shape = list(x_init_value.shape)
    assert(list(x_shape) == i_shape), "x_shape = %s, init_data shape = %s" % (
        x_shape, i_shape)
    x_data = x_init_value
  else:
    # NOTE(review): this compares `t` itself rather than `t.base_dtype`,
    # unlike the assertion above -- confirm whether non-base variants of
    # float32 are meant to fall through to np.float64.
    if t == types.float32:
      dtype = np.float32
    else:
      dtype = np.float64
    # np.asfarray was removed in NumPy 2.0. With an explicit floating dtype
    # np.asarray performs the same conversion on every NumPy version.
    x_data = np.asarray(np.random.random_sample(x_shape), dtype=dtype)

  jacob_t = _ComputeTheoricalJacobian(x, x_shape, x_data, dy, y_shape, dx)
  jacob_n = _ComputeNumericJacobian(x, x_shape, x_data, y, y_shape, delta)
  return jacob_t, jacob_n
开发者ID:adeelzaman,项目名称:tensorflow,代码行数:24,代码来源:gradient_checker.py


示例12: testMinMax

  def testMinMax(self):
    """Check `dtype.min` / `dtype.max` against known limits per numpy dtype.

    Quantized, bool, string and complex64 dtypes are skipped because no
    minimum/maximum is checked for them.
    """
    # make sure min/max evaluates for all data types that have min/max
    for datatype_enum in types_pb2.DataType.values():
      if datatype_enum == types_pb2.DT_INVALID:
        continue
      dtype = types.as_dtype(datatype_enum)
      numpy_dtype = dtype.as_numpy_dtype

      # ignore types for which there are no minimum/maximum (or we cannot
      # compute it, such as for the q* types)
      if (dtype.is_quantized or
          dtype.base_dtype == types.bool or
          dtype.base_dtype == types.string or
          dtype.base_dtype == types.complex64):
        continue

      print("%s: %s - %s" % (dtype, dtype.min, dtype.max))

      # check some values that are known
      if numpy_dtype == np.bool_:
        self.assertEqual(dtype.min, 0)
        self.assertEqual(dtype.max, 1)
      if numpy_dtype == np.int8:
        self.assertEqual(dtype.min, -128)
        self.assertEqual(dtype.max, 127)
      if numpy_dtype == np.int16:
        self.assertEqual(dtype.min, -32768)
        self.assertEqual(dtype.max, 32767)
      if numpy_dtype == np.int32:
        self.assertEqual(dtype.min, -2147483648)
        self.assertEqual(dtype.max, 2147483647)
      if numpy_dtype == np.int64:
        self.assertEqual(dtype.min, -9223372036854775808)
        self.assertEqual(dtype.max, 9223372036854775807)
      if numpy_dtype == np.uint8:
        self.assertEqual(dtype.min, 0)
        self.assertEqual(dtype.max, 255)
      if numpy_dtype == np.uint16:
        self.assertEqual(dtype.min, 0)
        # 2**16 - 1; the previous expectation was the uint32 maximum.
        self.assertEqual(dtype.max, 65535)
      if numpy_dtype == np.uint32:
        self.assertEqual(dtype.min, 0)
        # 2**32 - 1; the previous expectation was the uint64 maximum.
        self.assertEqual(dtype.max, 4294967295)
      if numpy_dtype in (np.float16, np.float32, np.float64):
        self.assertEqual(dtype.min, np.finfo(numpy_dtype).min)
        self.assertEqual(dtype.max, np.finfo(numpy_dtype).max)
开发者ID:debaratidas1994,项目名称:tensorflow,代码行数:46,代码来源:types_test.py


示例13: __init__

  def __init__(self, op, value_index, dtype):
    """Initializes a `Tensor` bound to one endpoint of an operation.

    The shape starts out as an unknown shape and the consumer list starts
    out empty.

    Args:
      op: An `Operation`. The operation that computes this tensor.
      value_index: An `int`. Index of the operation's endpoint that
        produces this tensor.
      dtype: A `types.DType`. Type of data stored in this tensor; it is
        normalized with `types.as_dtype`.

    Raises:
      TypeError: If the op is not an `Operation`.
    """
    if isinstance(op, Operation):
      self._op = op
    else:
      raise TypeError("op needs to be an Operation: %s" % op)
    self._value_index = value_index
    self._dtype = types.as_dtype(dtype)
    self._shape = tensor_shape.unknown_shape()
    # Operations that take this Tensor as an input, kept so the computation
    # graph can be navigated easily.
    self._consumers = []
开发者ID:iwannatoa,项目名称:tensorflow,代码行数:21,代码来源:ops.py


示例14: _restore_slice

def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                   name="restore_slice", preferred_shard=-1):
  """Restore a tensor slice from a set of files with a given pattern.

  Thin wrapper over `gen_io_ops._restore_slice` that first reduces
  `tensor_type` to its base dtype.

  Example usage:
    RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)

  Args:
    file_pattern: the file pattern used to match a set of checkpoint files.
    tensor_name: the name of the tensor to restore.
    shape_and_slice: the shape-and-slice spec of the slice.
    tensor_type: the type of the tensor to restore.
    name: string.  Optional name for the op.
    preferred_shard: Int. Optional shard to open first in the checkpoint file.

  Returns:
    A tensor of type "tensor_type".
  """
  restore_dtype = types.as_dtype(tensor_type).base_dtype
  # pylint: disable=protected-access
  restored = gen_io_ops._restore_slice(
      file_pattern, tensor_name, shape_and_slice, restore_dtype,
      preferred_shard, name=name)
  # pylint: enable=protected-access
  return restored
开发者ID:ray2020,项目名称:tensorflow,代码行数:22,代码来源:io_ops.py


示例15: testStringConversion

 def testStringConversion(self):
   """`as_dtype(name)` returns the `types` attribute of the same name.

   Covers each listed base name and its "_ref" variant, and checks that an
   unknown name raises TypeError.
   """
   base_names = (
       "float32", "float64", "int32", "uint8", "int16", "int8", "string",
       "complex64", "int64", "bool", "qint8", "quint8", "qint32", "bfloat16",
   )
   # All base names are checked first, then all "_ref" variants.
   for suffix in ("", "_ref"):
     for base_name in base_names:
       type_name = base_name + suffix
       self.assertIs(getattr(types, type_name), types.as_dtype(type_name))
   with self.assertRaises(TypeError):
     types.as_dtype("not_a_type")
开发者ID:debaratidas1994,项目名称:tensorflow,代码行数:31,代码来源:types_test.py


示例16: apply_op


#.........这里部分代码省略.........
        # * Convert values to Tensors if it contains constants.
        # * Verify that values is a list if that matches the input_arg's
        #   type.
        # * If the input_arg's type is determined by attrs, either set
        #   those attrs and validate those attr values are legal (if
        #   they have not yet been set) or validate the input matches
        #   the type indicated by the attrs (if they have already been
        #   inferred via an earlier input).
        # * If the input_arg has an explicit type, make sure the input
        #   conforms.

        if _IsListParameter(input_arg):
          if not _IsListValue(values):
            raise TypeError(
                "Expected list for '%s' argument to '%s' Op, not %s." %
                (input_name, op_type_name, values))
          # In cases where we expect all elements of the list to have the
          # same dtype, try to cast non-Tensor elements to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.number_attr:
            if input_arg.type_attr in attrs:
              dtype = attrs[input_arg.type_attr]
            else:
              for t in values:
                if isinstance(t, ops.Tensor):
                  dtype = t.dtype
                  break

          try:
            values = ops.convert_n_to_tensor_or_indexed_slices(
                values, name=input_arg.name,
                dtype=types_lib.as_dtype(dtype).base_dtype if dtype else None)
          except (TypeError, ValueError):
            assert dtype is not None, "Should not fail if dtype is None"
            assert input_arg.number_attr, "Should be number_attr case"
            # What types does the conversion function think values have?
            values = ops.convert_n_to_tensor_or_indexed_slices(values)
            observed = ", ".join(v.dtype.base_dtype.name for v in values)

            prefix = (
                "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
                (input_name, op_type_name, observed))
            if input_arg.type != types_pb2.DT_INVALID:
              raise TypeError("%s that do not match expected type %s." %
                              (prefix, types_lib.as_dtype(dtype).name))
            elif input_arg.type_attr in attrs:
              raise TypeError("%s that do not match type %s inferred from "
                              "earlier arguments." %
                              (prefix, types_lib.as_dtype(dtype).name))
            else:
              raise TypeError("%s that don't all match." % prefix)

          types = [x.dtype for x in values]
          inputs.extend(values)
        else:
          # In cases where we have an expected type, try to convert non-Tensor
          # arguments to that type.
          dtype = None
          if input_arg.type != types_pb2.DT_INVALID:
            dtype = input_arg.type
          elif input_arg.type_attr in attrs:
            dtype = attrs[input_arg.type_attr]

          try:
开发者ID:adeelzaman,项目名称:tensorflow,代码行数:67,代码来源:op_def_library.py


示例17: gradients


#.........这里部分代码省略.........

    # Initialize the pending count for ops in the connected subgraph from ys
    # to the xs.
    to_ops = [t.op for t in ys]
    from_ops = [t.op for t in xs]
    pending_count, has_control_flow = _PendingCount(
        ops.get_default_graph(), to_ops, from_ops)

    # Iterate over the collected ops.
    #
    # grads: op => list of gradients received on each output endpoint of the
    # op.  The gradients for each endpoint are initially collected as a list.
    # When it is time to call the op's gradient function, for each endpoint we
    # aggregate the list of received gradients into a Add() Operation if there
    # is more than one.
    grads = {}

    # Add the initial gradients for the ys.
    for y, grad_y in zip(ys, grad_ys):
      _SetGrad(grads, y, grad_y)

    # Initialize queue with to_ops.
    queue = collections.deque()
    # Add the ops in 'to_ops' into the queue.
    to_ops_set = set()
    for op in to_ops:
      if op._id not in to_ops_set:
        to_ops_set.add(op._id)
        queue.append(op)
    # The set of 'from_ops'.
    stop_ops = _StopOps(from_ops, pending_count)
    while queue:
      # generate gradient subgraph for op.
      op = queue.popleft()
      with ops.device(_GetGradsDevice(op, colocate_gradients_with_ops)):
        if has_control_flow:
          control_flow_ops.EnterGradWhileContext(op)
        out_grads = _AggregatedGrads(grads, op, has_control_flow,
                                     aggregation_method)
        grad_fn = None
        if any(out_grads) and op._id not in stop_ops:
          # A grad_fn must be defined, either as a function or as None
          # for ops that do not have gradients.
          try:
            grad_fn = ops.get_gradient_function(op)
          except LookupError:
            raise LookupError(
                "No gradient defined for operation '%s' (op type: %s)" %
                (op.name, op.type))
        if grad_fn and any(out_grads):
          # NOTE: If _AggregatedGrads didn't compute a value for the i'th
          # output, it means that the cost does not depend on output[i],
          # therefore dC/doutput[i] is 0.
          for i, out_grad in enumerate(out_grads):
            if (not out_grad
                and types.as_dtype(op.outputs[i].dtype).base_dtype in (
                    types.float32, types.float64)):
              # Only floating-point outputs get a zero gradient. Gradient
              # functions should ignore the gradient for other outputs.
              out_grads[i] = array_ops.zeros_like(op.outputs[i])
          with ops.name_scope(op.name + "_grad"):
            # pylint: disable=protected-access
            with ops.get_default_graph()._original_op(op):
            # pylint: enable=protected-access
              op_wrapper = op
              if has_control_flow:
                op_wrapper = control_flow_ops.MakeWrapper(op)
              in_grads = _AsList(grad_fn(op_wrapper, *out_grads))
              _VerifyGeneratedGradients(in_grads, op)
              if gate_gradients and len(in_grads) > 1:
                in_grads = control_flow_ops.tuple(in_grads)
          logging.vlog(1, "Gradient for '" + op.name + "'")
          logging.vlog(1, "  in  --> %s",
                       ", ".join([x.name for x in out_grads if x]))
          logging.vlog(1, "  out --> %s",
                       ", ".join([x.name for x in in_grads if x]))
        else:
          # If no grad_fn is defined or none of out_grads is available,
          # just propagates a list of None backwards.
          in_grads = [None] * len(op.inputs)
        for t_in, in_grad in zip(op.inputs, in_grads):
          if in_grad:
            _SetGrad(grads, t_in, in_grad)
        if has_control_flow:
          control_flow_ops.ExitGradWhileContext(op)

      # update pending count for the inputs of op.
      for x in op.inputs:
        pending_count[x.op._id] -= 1
        ready = (pending_count[x.op._id] == 0)
        if has_control_flow and not ready:
          ready = (pending_count[x.op._id] > 0 and
                   control_flow_ops.IsLoopSwitch(x.op))
        if ready:
          queue.append(x.op)
      for x in op.control_inputs:
        pending_count[x._id] -= 1
        if pending_count[x._id] is 0:
          queue.append(x)
  return [_GetGrad(grads, x) for x in xs]
开发者ID:bradg19,项目名称:tensor,代码行数:101,代码来源:gradients.py


示例18: MakeNdarray

def MakeNdarray(tensor):
    """Create a numpy ndarray from a tensor.

    Create a numpy ndarray with the same shape and data as the tensor.

    When `tensor.tensor_content` is set, the array is decoded from those
    raw bytes. Otherwise the values are read from the typed value field
    matching the tensor dtype. A field holding exactly one value (one
    real/imaginary pair for complex64) is repeated to fill the shape.

    Args:
      tensor: A TensorProto.

    Returns:
      A numpy array with the tensor contents.

    Raises:
      TypeError: if tensor has unsupported type.
    """
    shape = [d.size for d in tensor.tensor_shape.dim]
    num_elements = np.prod(shape)
    tensor_dtype = types.as_dtype(tensor.dtype)
    dtype = tensor_dtype.as_numpy_dtype

    if tensor.tensor_content:
        # np.fromstring is deprecated for binary input; np.frombuffer is the
        # documented replacement. frombuffer returns a read-only view of the
        # bytes, so copy to keep the result writable as it was before.
        return np.frombuffer(tensor.tensor_content, dtype=dtype).copy().reshape(shape)
    elif tensor_dtype == types.float32:
        if len(tensor.float_val) == 1:
            return np.repeat(np.array(tensor.float_val[0], dtype=dtype), num_elements).reshape(shape)
        else:
            return np.fromiter(tensor.float_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == types.float64:
        if len(tensor.double_val) == 1:
            return np.repeat(np.array(tensor.double_val[0], dtype=dtype), num_elements).reshape(shape)
        else:
            return np.fromiter(tensor.double_val, dtype=dtype).reshape(shape)
    elif tensor_dtype in [
        types.int32,
        types.uint8,
        types.int16,
        types.int8,
        types.qint32,
        types.quint8,
        types.qint8,
        types.bfloat16,
    ]:
        # All of these dtypes are read from the shared `int_val` field.
        if len(tensor.int_val) == 1:
            return np.repeat(np.array(tensor.int_val[0], dtype=dtype), num_elements).reshape(shape)
        else:
            return np.fromiter(tensor.int_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == types.int64:
        if len(tensor.int64_val) == 1:
            return np.repeat(np.array(tensor.int64_val[0], dtype=dtype), num_elements).reshape(shape)
        else:
            return np.fromiter(tensor.int64_val, dtype=dtype).reshape(shape)
    elif tensor_dtype == types.string:
        if len(tensor.string_val) == 1:
            return np.repeat(np.array(str(tensor.string_val[0]), dtype=dtype), num_elements).reshape(shape)
        else:
            return np.array([str(x) for x in tensor.string_val], dtype=dtype).reshape(shape)
    elif tensor_dtype == types.complex64:
        # `scomplex_val` is consumed as flat (real, imaginary) pairs:
        # zipping one iterator with itself yields consecutive pairs.
        it = iter(tensor.scomplex_val)
        if len(tensor.scomplex_val) == 2:
            return np.repeat(
                np.array(complex(tensor.scomplex_val[0], tensor.scomplex_val[1]), dtype=dtype), num_elements
            ).reshape(shape)
        else:
            return np.array([complex(x[0], x[1]) for x in zip(it, it)], dtype=dtype).reshape(shape)
    elif tensor_dtype == types.bool:
        if len(tensor.bool_val) == 1:
            return np.repeat(np.array(tensor.bool_val[0], dtype=dtype), num_elements).reshape(shape)
        else:
            return np.fromiter(tensor.bool_val, dtype=dtype).reshape(shape)
    else:
        raise TypeError("Unsupported tensor type: %s" % tensor.dtype)
开发者ID:adeelzaman,项目名称:tensorflow,代码行数:71,代码来源:tensor_util.py


示例19: testAllTypesConvertibleToDType

 def testAllTypesConvertibleToDType(self):
   """Each valid DataType enum survives a round trip through `as_dtype`."""
   for enum_value in types_pb2.DataType.values():
     if enum_value != types_pb2.DT_INVALID:
       round_tripped = types.as_dtype(enum_value).as_datatype_enum
       self.assertEqual(enum_value, round_tripped)
开发者ID:debaratidas1994,项目名称:tensorflow,代码行数:6,代码来源:types_test.py


示例20: testInvalid

 def testInvalid(self):
   """Both `DType` and `as_dtype` raise TypeError for DT_INVALID."""
   for convert in (types.DType, types.as_dtype):
     with self.assertRaises(TypeError):
       convert(types_pb2.DT_INVALID)
开发者ID:debaratidas1994,项目名称:tensorflow,代码行数:5,代码来源:types_test.py



注:本文中的tensorflow.python.framework.types.as_dtype函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python backend.floatx函数代码示例发布时间:2022-05-27
下一篇:
Python models.send_emails函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap