
Python sparse.deserialize_sparse_tensors Function Code Examples


This article collects typical usage examples of the tensorflow.python.data.util.sparse.deserialize_sparse_tensors function in Python. If you are wondering how deserialize_sparse_tensors works, how to call it, or what real-world uses of it look like, the selected code examples here may help.



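A total of 19 code examples of the deserialize_sparse_tensors function are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your ratings help our system recommend better Python code examples.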
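Before the examples, a toy model of the serialize/deserialize round trip may help. The examples below call deserialize_sparse_tensors with either two or four arguments, which suggests the signature of this internal module changed between TensorFlow versions, so the sketch does not depend on TensorFlow at all. `ToySparse`, `toy_get_classes`, `toy_serialize`, and `toy_deserialize` are hypothetical names introduced for illustration only; they are not TensorFlow APIs and do not reproduce its implementation.

```python
from collections import namedtuple

# Hypothetical stand-in for a sparse value; NOT TensorFlow's SparseTensor.
ToySparse = namedtuple("ToySparse", ["indices", "values", "dense_shape"])


def toy_get_classes(element):
    """Returns a structure mirroring `element` that records each leaf's class."""
    if isinstance(element, ToySparse):
        return ToySparse
    if isinstance(element, tuple):
        return tuple(toy_get_classes(e) for e in element)
    return type(element)


def toy_serialize(element):
    """Replaces every ToySparse leaf with a plain list of its three components."""
    if isinstance(element, ToySparse):
        return [element.indices, element.values, element.dense_shape]
    if isinstance(element, tuple):
        return tuple(toy_serialize(e) for e in element)
    return element


def toy_deserialize(serialized, classes):
    """Rebuilds ToySparse leaves wherever `classes` says a leaf was sparse."""
    if classes is ToySparse:
        indices, values, dense_shape = serialized
        return ToySparse(indices, values, dense_shape)
    if isinstance(classes, tuple):
        return tuple(toy_deserialize(s, c) for s, c in zip(serialized, classes))
    return serialized
```

The point of the sketch is that the serialized data alone does not say which leaves were sparse; a parallel structure of classes has to be passed in, which is the role the `output_classes` and `_state_classes` arguments appear to play in the four-argument calls below.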

Example 1: tf_reduce_func

      def tf_reduce_func(*args):
        """A wrapper for Defun that facilitates shape inference."""
        for arg, shape in zip(
            args,
            nest.flatten(
                sparse.as_dense_shapes(self._state_shapes, self._state_classes))
            + nest.flatten(
                sparse.as_dense_shapes(input_dataset.output_shapes,
                                       input_dataset.output_classes))):
          arg.set_shape(shape)

        pivot = len(nest.flatten(self._state_shapes))
        nested_state_args = nest.pack_sequence_as(self._state_types,
                                                  args[:pivot])
        nested_state_args = sparse.deserialize_sparse_tensors(
            nested_state_args, self._state_types, self._state_shapes,
            self._state_classes)
        nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                                  args[pivot:])
        nested_input_args = sparse.deserialize_sparse_tensors(
            nested_input_args, input_dataset.output_types,
            input_dataset.output_shapes, input_dataset.output_classes)

        ret = reduce_func(nested_state_args, nested_input_args)

        # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
        # values to tensors.
        ret = nest.pack_sequence_as(ret, [
            sparse_tensor.SparseTensor.from_value(t)
            if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
            for t in nest.flatten(ret)
        ])

        # Extract shape information from the returned values.
        flat_new_state = nest.flatten(ret)
        flat_new_state_shapes.extend([t.get_shape() for t in flat_new_state])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(flat_new_state, nest.flatten(self._state_types)):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(self._state_types,
                                       [t.dtype for t in flat_new_state])))

        dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

        # Serialize any sparse tensors.
        ret = nest.pack_sequence_as(
            ret,
            [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
        return nest.flatten(ret)
Developer ID: xman, Project: tensorflow, Lines of code: 54, Source: grouping.py
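The pivot logic in the example above (split the flat `args` into a state part and an input part, then restore each part's nesting) can be sketched in plain Python. This is a simplified illustration that only understands nested tuples; `flatten`, `pack_sequence_as`, and `split_state_and_input` here are hypothetical helpers written for this sketch, not the `nest` utilities from TensorFlow.

```python
def flatten(structure):
    """Flattens a nested tuple into a list of leaves, left to right."""
    if isinstance(structure, tuple):
        flat = []
        for item in structure:
            flat.extend(flatten(item))
        return flat
    return [structure]


def pack_sequence_as(structure, flat):
    """Arranges the values in `flat` into the nesting of `structure`."""
    flat = list(flat)
    if len(flat) != len(flatten(structure)):
        raise ValueError("Structure and flat sequence have different sizes.")
    leaves = iter(flat)

    def _pack(node):
        if isinstance(node, tuple):
            return tuple(_pack(child) for child in node)
        return next(leaves)

    return _pack(structure)


def split_state_and_input(args, state_structure, input_structure):
    """Splits flat `args` at the pivot and re-nests both halves."""
    pivot = len(flatten(state_structure))
    nested_state = pack_sequence_as(state_structure, args[:pivot])
    nested_input = pack_sequence_as(input_structure, args[pivot:])
    return nested_state, nested_input
```

For instance, with a state structure of `("a", ("b", "c"))` and an input structure of `("d",)`, the first three flat arguments are treated as state and the remaining one as input.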


Example 2: get_next

  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types,
                              gen_dataset_ops.iterator_get_next(
                                  self._iterator_resource,
                                  output_types=nest.flatten(
                                      sparse.as_dense_types(
                                          self._output_types,
                                          self._output_classes)),
                                  output_shapes=nest.flatten(
                                      sparse.as_dense_shapes(
                                          self._output_shapes,
                                          self._output_classes)),
                                  name=name)), self._output_types,
        self._output_shapes, self._output_classes)
Developer ID: modkzs, Project: tensorflow, Lines of code: 27, Source: iterator_ops.py


Example 3: testSerializeDeserialize

 def testSerializeDeserialize(self):
   test_cases = (
       (),
       sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
       sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensor(
           indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
       ((), sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
   )
   for expected in test_cases:
     classes = sparse.get_classes(expected)
     shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                 classes)
     types = nest.map_structure(lambda _: dtypes.int32, classes)
     actual = sparse.deserialize_sparse_tensors(
         sparse.serialize_sparse_tensors(expected), types, shapes,
         sparse.get_classes(expected))
     nest.assert_same_structure(expected, actual)
     for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
       self.assertSparseValuesEqual(a, e)
Developer ID: abidrahmank, Project: tensorflow, Lines of code: 27, Source: sparse_test.py


Example 4: get_next

  def get_next(self, name=None):
    """See `tf.data.Iterator.get_next`."""
    self._get_next_call_count += 1
    if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

    flat_result = []
    # TODO(priyag): This will fail if the input size (typically number of
    # batches) is not divisible by number of devices.
    # How do we handle that more gracefully / let the user know?
    for buffer_resource in self._buffering_resources:
      flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
          buffer_resource,
          output_types=data_nest.flatten(sparse.as_dense_types(
              self.output_types, self.output_classes)), name=name)

      ret = sparse.deserialize_sparse_tensors(
          data_nest.pack_sequence_as(self.output_types, flat_ret),
          self.output_types, self.output_shapes, self.output_classes)

      for tensor, shape in zip(
          data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
        if isinstance(tensor, ops.Tensor):
          tensor.set_shape(shape)
      flat_result.append(ret)

    return nest.pack_sequence_as(self._devices, flat_result)
Developer ID: AnishShah, Project: tensorflow, Lines of code: 27, Source: prefetching_ops_v2.py


Example 5: tf_map_func

    def tf_map_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
        dataset = map_func(*nested_args)
      else:
        dataset = map_func(nested_args)

      if not isinstance(dataset, dataset_ops.Dataset):
        raise TypeError("`map_func` must return a `Dataset` object.")

      self._output_classes = dataset.output_classes
      self._output_types = dataset.output_types
      self._output_shapes = dataset.output_shapes

      return dataset._as_variant_tensor()  # pylint: disable=protected-access
Developer ID: AnddyWang, Project: tensorflow, Lines of code: 25, Source: interleave_ops.py


Example 6: tf_finalize_func

    def tf_finalize_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      for arg, shape in zip(
          args,
          nest.flatten(
              sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(self._state_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, self._state_types, self._state_shapes,
          self._state_classes)

      ret = finalize_func(nested_args)

      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor.SparseTensor.from_value(t)
          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
          for t in nest.flatten(ret)
      ])

      self._output_classes = sparse.get_classes(ret)
      self._output_shapes = nest.pack_sequence_as(
          ret, [t.get_shape() for t in nest.flatten(ret)])
      self._output_types = nest.pack_sequence_as(
          ret, [t.dtype for t in nest.flatten(ret)])

      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access

      # Serialize any sparse tensors.
      ret = nest.pack_sequence_as(
          ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
      return nest.flatten(ret)
Developer ID: xman, Project: tensorflow, Lines of code: 35, Source: grouping.py


Example 7: tf_key_func

    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      # pylint: disable=protected-access
      if dataset_ops._should_unpack_args(nested_args):
        ret = key_func(*nested_args)
      # pylint: enable=protected-access
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret)
      if ret.dtype != dtypes.int64 or ret.get_shape() != tensor_shape.scalar():
        raise ValueError(
            "`key_func` must return a single tf.int64 tensor. "
            "Got type=%s and shape=%s" % (ret.dtype, ret.get_shape()))
      dataset_ops._warn_if_collections("tf.contrib.data.group_by_reducer()")  # pylint: disable=protected-access
      return ret
Developer ID: xman, Project: tensorflow, Lines of code: 25, Source: grouping.py


Example 8: tf_finalize_func

    def tf_finalize_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      for arg, shape in zip(
          args,
          nest.flatten(
              sparse.as_dense_shapes(self._state_shapes, self._state_classes))):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(self._state_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, self._state_types, self._state_shapes,
          self._state_classes)

      ret = finalize_func(nested_args)

      # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
      # values to tensors.
      ret = nest.pack_sequence_as(ret, [
          sparse_tensor.SparseTensor.from_value(t)
          if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
          for t in nest.flatten(ret)
      ])

      self._output_classes = sparse.get_classes(ret)
      self._output_shapes = nest.pack_sequence_as(
          ret, [t.get_shape() for t in nest.flatten(ret)])
      self._output_types = nest.pack_sequence_as(
          ret, [t.dtype for t in nest.flatten(ret)])

      # Serialize any sparse tensors.
      ret = nest.pack_sequence_as(
          ret, [t for t in nest.flatten(sparse.serialize_sparse_tensors(ret))])
      return nest.flatten(ret)
Developer ID: Jackiefan, Project: tensorflow, Lines of code: 33, Source: grouping.py


Example 9: _next_internal

  def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.
    """
    with ops.device(self._device):
      if self._buffer_resource_handle is not None:
        ret = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=self._buffer_resource_handle,
            output_types=self._flat_output_types)
      else:
        # TODO(ashankar): Consider removing this ops.device() contextmanager
        # and instead mimic ops placement in graphs: Operations on resource
        # handles execute on the same device as where the resource is placed.
        # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
        # because in eager mode this code will run synchronously on the calling
        # thread. Therefore we do not need to make a defensive context switch
        # to a background thread, and can achieve a small constant performance
        # boost by invoking the iterator synchronously.
        ret = gen_dataset_ops.iterator_get_next_sync(
            self._resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)

    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types, ret), self._output_types,
        self._output_shapes, self._output_classes)
Developer ID: DILASSS, Project: tensorflow, Lines of code: 25, Source: datasets.py


Example 10: _next_internal

 def _next_internal(self):
   """Returns a nested structure of `tf.Tensor`s containing the next element.
   """
   if self._buffer_resource_handle is not None:
     with ops.device(self._device):
       ret = prefetching_ops.function_buffering_resource_get_next(
           function_buffer_resource=self._buffer_resource_handle,
           output_types=self._flat_output_types)
     return sparse.deserialize_sparse_tensors(
         nest.pack_sequence_as(self._output_types, ret), self._output_types,
         self._output_shapes, self._output_classes)
   else:
     return super(Iterator, self)._next_internal()
Developer ID: AndrewTwinz, Project: tensorflow, Lines of code: 13, Source: datasets.py


Example 11: get_single_element

def get_single_element(dataset):
  """Returns the single element in `dataset` as a nested structure of tensors.

  This function enables you to use a @{tf.data.Dataset} in a stateless
  "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}.
  This can be useful when your preprocessing transformations are expressed
  as a `Dataset`, and you want to use the transformation at serving time.
  For example:

  ```python
  input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])

  def preprocessing_fn(input_str):
    # ...
    return image, label

  dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
             .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
             .batch(BATCH_SIZE))

  image_batch, label_batch = tf.contrib.data.get_single_element(dataset)
  ```

  Args:
    dataset: A @{tf.data.Dataset} object containing a single element.

  Returns:
    A nested structure of @{tf.Tensor} objects, corresponding to the single
    element of `dataset`.

  Raises:
    TypeError: if `dataset` is not a `tf.data.Dataset` object.
    InvalidArgumentError (at runtime): if `dataset` does not contain exactly
      one element.
  """
  if not isinstance(dataset, dataset_ops.Dataset):
    raise TypeError("`dataset` must be a `tf.data.Dataset` object.")

  nested_ret = nest.pack_sequence_as(
      dataset.output_types, gen_dataset_ops.dataset_to_single_element(
          dataset._as_variant_tensor(),  # pylint: disable=protected-access
          output_types=nest.flatten(sparse.as_dense_types(
              dataset.output_types, dataset.output_classes)),
          output_shapes=nest.flatten(sparse.as_dense_shapes(
              dataset.output_shapes, dataset.output_classes))))
  return sparse.deserialize_sparse_tensors(
      nested_ret, dataset.output_types, dataset.output_shapes,
      dataset.output_classes)
Developer ID: AndrewTwinz, Project: tensorflow, Lines of code: 48, Source: get_single_element.py
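The "exactly one element" contract described in the docstring above can be illustrated on ordinary Python iterables. This is only an analogy under stated assumptions: `single_element` is a hypothetical helper, and it raises `ValueError` eagerly, whereas the docstring says the TensorFlow function reports an `InvalidArgumentError` at runtime.

```python
def single_element(iterable):
    """Returns the only element of `iterable`, or raises if the count is not 1."""
    iterator = iter(iterable)
    try:
        first = next(iterator)
    except StopIteration:
        raise ValueError("Expected exactly one element, got none.") from None
    # Any further element violates the contract.
    for _ in iterator:
        raise ValueError("Expected exactly one element, got more than one.")
    return first
```

Calling `single_element([batch])` returns `batch`, while an empty or multi-element input is rejected instead of silently returning the first item.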


Example 12: get_value

 def get_value(self, name=None):
   # TODO(b/110122868): Consolidate the restructuring logic with similar logic
   # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
   with ops.name_scope(name, "OptionalGetValue",
                       [self._variant_tensor]) as scope:
     return sparse.deserialize_sparse_tensors(
         nest.pack_sequence_as(
             self._output_types,
             gen_dataset_ops.optional_get_value(
                 self._variant_tensor,
                 name=scope,
                 output_types=nest.flatten(
                     sparse.as_dense_types(self._output_types,
                                           self._output_classes)),
                 output_shapes=nest.flatten(
                     sparse.as_dense_shapes(self._output_shapes,
                                            self._output_classes)))),
         self._output_types, self._output_shapes, self._output_classes)
Developer ID: AnishShah, Project: tensorflow, Lines of code: 18, Source: optional_ops.py


Example 13: tf_key_func

    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types)
      # pylint: disable=protected-access
      if dataset_ops._should_unpack_args(nested_args):
        ret = key_func(*nested_args)
      # pylint: enable=protected-access
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
      if ret.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
      return ret
Developer ID: SylChan, Project: tensorflow, Lines of code: 19, Source: grouping.py


Example 14: get_next

  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types,
                              gen_dataset_ops.iterator_get_next(
                                  self._iterator_resource,
                                  output_types=nest.flatten(
                                      sparse.unwrap_sparse_types(
                                          self._output_types)),
                                  output_shapes=nest.flatten(
                                      self._output_shapes),
                                  name=name)), self._output_types)
Developer ID: SylChan, Project: tensorflow, Lines of code: 19, Source: iterator_ops.py


Example 15: _next_internal

  def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.
    """
    with ops.device(self._device):
      if self._buffer_resource_handle is not None:
        ret = prefetching_ops.function_buffering_resource_get_next(
            function_buffer_resource=self._buffer_resource_handle,
            output_types=self._flat_output_types)
      else:
        # TODO(ashankar): Consider removing this ops.device() contextmanager
        # and instead mimic ops placement in graphs: Operations on resource
        # handles execute on the same device as where the resource is placed.
        ret = gen_dataset_ops.iterator_get_next(
            self._resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)

    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types, ret), self._output_types,
        self._output_shapes, self._output_classes)
Developer ID: Fair-Child, Project: tensorflow, Lines of code: 20, Source: datasets.py


Example 16: testSerializeDeserialize

 def testSerializeDeserialize(self):
   test_cases = (
       (),
       sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
       sparse_tensor.SparseTensor(
           indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
       sparse_tensor.SparseTensor(
           indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
       (sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
       ((), sparse_tensor.SparseTensor(
           indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
   )
   for expected in test_cases:
     actual = sparse.deserialize_sparse_tensors(
         sparse.serialize_sparse_tensors(expected),
         sparse.get_sparse_types(expected))
     nest.assert_same_structure(expected, actual)
     for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
       self.assertSparseValuesEqual(a, e)
Developer ID: SylChan, Project: tensorflow, Lines of code: 23, Source: sparse_test.py


Example 17: _next_internal

  def _next_internal(self):
    """Returns a nested structure of `tf.Tensor`s containing the next element.
    """
    # This runs in sync mode as iterators use an error status to communicate
    # that there is no more data to iterate over.
    # TODO(b/77291417): Fix
    with context.execution_mode(context.SYNC):
      with ops.device(self._device):
        # TODO(ashankar): Consider removing this ops.device() contextmanager
        # and instead mimic ops placement in graphs: Operations on resource
        # handles execute on the same device as where the resource is placed.
        # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next`
        # because in eager mode this code will run synchronously on the calling
        # thread. Therefore we do not need to make a defensive context switch
        # to a background thread, and can achieve a small constant performance
        # boost by invoking the iterator synchronously.
        ret = gen_dataset_ops.iterator_get_next_sync(
            self._resource,
            output_types=self._flat_output_types,
            output_shapes=self._flat_output_shapes)

      return sparse.deserialize_sparse_tensors(
          nest.pack_sequence_as(self._output_types, ret), self._output_types,
          self._output_shapes, self._output_classes)
Developer ID: Jackiefan, Project: tensorflow, Lines of code: 24, Source: iterator_ops.py


Example 18: get_next

  def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s representing the next element.

    In graph mode, you should typically call this method *once* and use its
    result as the input to another computation. A typical loop will then call
    @{tf.Session.run} on the result of that computation. The loop will terminate
    when the `Iterator.get_next()` operation raises
    @{tf.errors.OutOfRangeError}. The following skeleton shows how to use
    this method when building a training loop:

    ```python
    dataset = ...  # A `tf.data.Dataset` object.
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Build a TensorFlow graph that does something with each element.
    loss = model_function(next_element)
    optimizer = ...  # A `tf.train.Optimizer` object.
    train_op = optimizer.minimize(loss)

    with tf.Session() as sess:
      try:
        while True:
          sess.run(train_op)
      except tf.errors.OutOfRangeError:
        pass
    ```

    NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g.
    when you are distributing different elements to multiple devices in a single
    step. However, a common pitfall arises when users call `Iterator.get_next()`
    in each iteration of their training loop. `Iterator.get_next()` adds ops to
    the graph, and executing each op allocates resources (including threads); as
    a consequence, invoking it in every iteration of a training loop causes
    slowdown and eventual resource exhaustion. To guard against this outcome, we
    log a warning when the number of uses crosses a fixed threshold of
    suspiciousness.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    self._get_next_call_count += 1
    if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
      warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)

    return sparse.deserialize_sparse_tensors(
        nest.pack_sequence_as(self._output_types,
                              gen_dataset_ops.iterator_get_next(
                                  self._iterator_resource,
                                  output_types=nest.flatten(
                                      sparse.as_dense_types(
                                          self._output_types,
                                          self._output_classes)),
                                  output_shapes=nest.flatten(
                                      sparse.as_dense_shapes(
                                          self._output_shapes,
                                          self._output_classes)),
                                  name=name)), self._output_types,
        self._output_shapes, self._output_classes)
Developer ID: Jackiefan, Project: tensorflow, Lines of code: 62, Source: iterator_ops.py
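The call-count guard at the end of the example above (increment a counter on every `get_next()` call and warn once it passes a threshold) can be reproduced in isolation. The sketch below is a toy under stated assumptions: `ToyIterator`, `TOY_THRESHOLD`, and the warning text are hypothetical, and the real threshold and message constants are not shown in the snippet.

```python
import warnings

# Hypothetical threshold; the real GET_NEXT_CALL_WARNING_THRESHOLD value
# is not shown in the example above.
TOY_THRESHOLD = 3


class ToyIterator:
    """Counts get_next() calls and warns when they exceed the threshold."""

    def __init__(self):
        self._get_next_call_count = 0

    def get_next(self):
        self._get_next_call_count += 1
        if self._get_next_call_count > TOY_THRESHOLD:
            warnings.warn(
                "get_next() was called %d times; call it once outside the "
                "training loop and reuse the result."
                % self._get_next_call_count)
        # Stand-in for the op that a real iterator would add to the graph.
        return "next_element_%d" % self._get_next_call_count
```

The guard does not prevent repeated calls, which the docstring notes can be legitimate; it only makes the common mistake of calling `get_next()` inside the loop visible.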


Example 19: tf_scan_func

      def tf_scan_func(*args):
        """A wrapper for Defun that facilitates shape inference."""
        # Pass in shape information from the state and input_dataset.
        for arg, shape in zip(
            args,
            nest.flatten(
                sparse.as_dense_shapes(self._state_shapes, self._state_classes))
            + nest.flatten(
                sparse.as_dense_shapes(input_dataset.output_shapes,
                                       input_dataset.output_classes))):
          arg.set_shape(shape)

        pivot = len(nest.flatten(self._state_shapes))
        nested_state_args = nest.pack_sequence_as(self._state_types,
                                                  args[:pivot])
        nested_state_args = sparse.deserialize_sparse_tensors(
            nested_state_args, self._state_types, self._state_shapes,
            self._state_classes)
        nested_input_args = nest.pack_sequence_as(input_dataset.output_types,
                                                  args[pivot:])
        nested_input_args = sparse.deserialize_sparse_tensors(
            nested_input_args, input_dataset.output_types,
            input_dataset.output_shapes, input_dataset.output_classes)

        ret = scan_func(nested_state_args, nested_input_args)
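        # `collections.Sequence` was removed in Python 3.10; this form needs
        # `import collections.abc` at module level.
        if not isinstance(ret, collections.abc.Sequence) or len(ret) != 2: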
          raise TypeError("The scan function must return a pair comprising the "
                          "new state and the output value.")

        # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
        # values to tensors.
        ret = nest.pack_sequence_as(ret, [
            sparse_tensor.SparseTensor.from_value(t)
            if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(t)
            for t in nest.flatten(ret)
        ])
        new_state, output_value = ret

        # Extract and validate class information from the returned values.
        for t, clazz in zip(
            nest.flatten(new_state), nest.flatten(self._state_classes)):
          if not isinstance(t, clazz):
            raise TypeError(
                "The element classes for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_classes,
                 nest.pack_sequence_as(
                     self._state_types,
                     [type(t) for t in nest.flatten(new_state)])))
        self._output_classes = sparse.get_classes(output_value)

        # Extract shape information from the returned values.
        flat_new_state_shapes.extend(
            [t.get_shape() for t in nest.flatten(new_state)])
        self._output_shapes = nest.pack_sequence_as(
            output_value, [t.get_shape() for t in nest.flatten(output_value)])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(
            nest.flatten(new_state), nest.flatten(self._state_types)):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types,
                 nest.pack_sequence_as(
                     self._state_types,
                     [t.dtype for t in nest.flatten(new_state)])))
        self._output_types = nest.pack_sequence_as(
            output_value, [t.dtype for t in nest.flatten(output_value)])

        dataset_ops._warn_if_collections("tf.contrib.data.scan()")  # pylint: disable=protected-access

        # Serialize any sparse tensors.
        new_state = nest.pack_sequence_as(new_state, [
            t for t in nest.flatten(sparse.serialize_sparse_tensors(new_state))
        ])
        output_value = nest.pack_sequence_as(output_value, [
            t for t in nest.flatten(
                sparse.serialize_sparse_tensors(output_value))
        ])
        return nest.flatten(new_state) + nest.flatten(output_value)
Developer ID: xman, Project: tensorflow, Lines of code: 84, Source: scan_ops.py
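The contract that the wrapper above enforces (the scan function must return a pair of new state and output value, and the new state must keep the class of the initial state) can be shown with a plain-Python scan over a list. `toy_scan` is a hypothetical helper for illustration; it checks Python types eagerly and does not model tensors, shapes, or dtypes.

```python
from collections.abc import Sequence


def toy_scan(initial_state, scan_func, elements):
    """Threads a state through `elements`, collecting one output per element."""
    state = initial_state
    outputs = []
    for element in elements:
        ret = scan_func(state, element)
        if not isinstance(ret, Sequence) or len(ret) != 2:
            raise TypeError("The scan function must return a pair comprising "
                            "the new state and the output value.")
        new_state, output_value = ret
        # Mirror the class check: the state must keep its initial class.
        if type(new_state) is not type(initial_state):
            raise TypeError("The class of the new state must match the "
                            "initial state.")
        state = new_state
        outputs.append(output_value)
    return outputs
```

A running sum is the usual small example: `toy_scan(0, lambda s, x: (s + x, s + x), [1, 2, 3])` carries the partial sum as state and emits it as the output for each element.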



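Note: The tensorflow.python.data.util.sparse.deserialize_sparse_tensors examples in this article were compiled by 纯净天空 from source-code and documentation hosting platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by their respective developers, and copyright in the source code belongs to the original authors. For distribution and use, refer to the license of the corresponding project. Do not reproduce without permission.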

