
Python nest.map_structure_up_to Function Code Examples


This article collects typical usage examples of the tensorflow.python.data.util.nest.map_structure_up_to function in Python. If you are wondering what map_structure_up_to does, how to call it, or what real uses of it look like, the selected code examples below may help.



A total of 12 code examples of the map_structure_up_to function are shown below, sorted by popularity by default.

Example 1: testMapStructureUpTo

  def testMapStructureUpTo(self):
    ab_tuple = collections.namedtuple("ab_tuple", "a, b")
    op_tuple = collections.namedtuple("op_tuple", "add, mul")
    inp_val = ab_tuple(a=2, b=3)
    inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
    out = nest.map_structure_up_to(
        inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)
    self.assertEqual(out.a, 6)
    self.assertEqual(out.b, 15)

    data_list = ((2, 4, 6, 8), ((1, 3, 5, 7, 9), (3, 5, 7)))
    name_list = ("evens", ("odds", "primes"))
    out = nest.map_structure_up_to(
        name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
        name_list, data_list)
    self.assertEqual(out, ("first_4_evens", ("first_5_odds", "first_3_primes")))
Developer ID: abidrahmank, Project: tensorflow, Lines of code: 16, Source: nest_test.py
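The behaviour exercised by this test can be approximated without TensorFlow. The following is a minimal pure-Python sketch, not the library's implementation. It assumes that only tuples (including namedtuples) and dicts count as nesting and that everything else, lists included, is a leaf, and it skips the structure validation a real implementation would perform. The local `map_structure_up_to` is a stand-in written for illustration.

```python
import collections


def map_structure_up_to(shallow, func, *inputs):
    """Apply func to the parts of inputs that line up with the leaves of shallow.

    Illustrative sketch only. Assumption: tuples and dicts are nesting,
    everything else (including lists) is a leaf. No validation is done.
    """
    if isinstance(shallow, dict):
        return {key: map_structure_up_to(shallow[key], func,
                                         *[inp[key] for inp in inputs])
                for key in shallow}
    if isinstance(shallow, tuple):
        mapped = [map_structure_up_to(sub, func, *[inp[i] for inp in inputs])
                  for i, sub in enumerate(shallow)]
        if hasattr(shallow, "_fields"):  # namedtuple: rebuild by position
            return type(shallow)(*mapped)
        return tuple(mapped)
    # shallow is a leaf: stop descending, even if the inputs are still nested.
    return func(*inputs)


ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
out_ops = map_structure_up_to(
    inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops)

data_list = ((2, 4, 6, 8), ((1, 3, 5, 7, 9), (3, 5, 7)))
name_list = ("evens", ("odds", "primes"))
out_names = map_structure_up_to(
    name_list, lambda name, sec: "first_{}_{}".format(len(sec), name),
    name_list, data_list)
```

The mapping stops at the leaves of the first argument, so `ops` arrives as a whole `op_tuple` and `sec` as a whole tuple of numbers.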


Example 2: tf_shapes

 def tf_shapes(self):
     """
     :return: a dictionary of sampler output tensor shapes
     """
     output_shapes = nest.map_structure_up_to(
         self.tf_dtypes, tf.TensorShape, self.shapes)
     return output_shapes
Developer ID: fepegar, Project: NiftyNet, Lines of code: 7, Source: image_window.py
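A likely reason for using the "up to" variant here is that each shape is itself a sequence, so a full-depth map would visit every dimension instead of every shape. The sketch below contrasts the two in plain Python. Both helper functions are simplified stand-ins written for this illustration (single input, tuples and dicts only), and the dictionary values are sample data, not taken from NiftyNet internals.

```python
def map_structure(func, structure):
    """Full-depth map over tuples and dicts (illustrative sketch)."""
    if isinstance(structure, dict):
        return {k: map_structure(func, v) for k, v in structure.items()}
    if isinstance(structure, tuple):
        return tuple(map_structure(func, v) for v in structure)
    return func(structure)


def map_structure_up_to(shallow, func, structure):
    """Map that stops at the leaves of shallow (illustrative sketch)."""
    if isinstance(shallow, dict):
        return {k: map_structure_up_to(shallow[k], func, structure[k])
                for k in shallow}
    if isinstance(shallow, tuple):
        return tuple(map_structure_up_to(s, func, v)
                     for s, v in zip(shallow, structure))
    return func(structure)


# Sample data: dtype names are leaves, shapes are tuples of ints.
dtypes = {"image": "float32", "label": "float32"}
shapes = {"image": (192, 160, 192, 1, 2), "label": (192, 160, 192, 1, 1)}

# Stops at the dtype leaves, so func receives each whole shape tuple.
ranks = map_structure_up_to(dtypes, len, shapes)

# Descends all the way, so func receives each individual dimension.
doubled = map_structure(lambda dim: dim * 2, shapes)
```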


Example 3: from_string_handle

  def from_string_handle(string_handle, output_types, output_shapes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a `tf.Session.run` call.
    In that case, `string_handle` would be a `tf.placeholder`, and you would
    feed it with the value of `tf.data.Iterator.string_handle` in each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`)
        objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`)
        component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    nest.assert_same_structure(output_types, output_shapes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    iterator_resource = gen_dataset_ops.iterator_from_string_handle(
        string_handle,
        output_types=nest.flatten(sparse.unwrap_sparse_types(output_types)),
        output_shapes=nest.flatten(output_shapes))
    return Iterator(iterator_resource, None, output_types, output_shapes)
Developer ID: SylChan, Project: tensorflow, Lines of code: 57, Source: iterator_ops.py


Example 4: set_spatial_shape

    def set_spatial_shape(self, spatial_window, source_names=None):
        """
        Set the spatial window sizes of all image windows.

        spatial_window should be a dictionary of window size tuples
        or a single window size tuple.  In the latter case the size
        will be used by all output image windows.

        :param spatial_window: tuple of integers specifying new shape
        :param source_names: list/dictionary of input source names
        :return:
        """
        win_sizes = copy.deepcopy(spatial_window)
        if isinstance(spatial_window, dict):
            for name in list(spatial_window):
                window_size = spatial_window[name]
                if isinstance(window_size,
                              (ParserNamespace, argparse.Namespace)):
                    window_size = vars(window_size)
                if not isinstance(window_size, dict):
                    win_sizes[name] = tuple(window_size)
                elif 'spatial_window_size' in window_size:
                    win_sizes[name] = tuple(
                        window_size['spatial_window_size'])
                else:
                    raise ValueError(
                        'window_sizes should be a nested dictionary')
        elif isinstance(spatial_window, (list, tuple)):
            # list or tuple of single window sizes
            win_sizes = {name: spatial_window for name in list(self._dtypes)}

        # complete window shapes based on user input and input_image sizes
        if source_names:
            spatial_shapes = _read_window_sizes(source_names, win_sizes)
        else:
            try:
                spatial_shapes = {}
                for name in list(self._dtypes):
                    spatial_shapes[name] = \
                        tuple(int(win_size) for win_size in win_sizes[name])
            except ValueError:
                tf.logging.fatal("spatial window should be an array of int")
                raise

        spatial_shapes = nest.map_structure_up_to(
            self._dtypes, tuple, spatial_shapes)

        self._shapes = {
            name: _complete_partial_window_sizes(spatial_shapes[name],
                                                 self._shapes[name])
            for name in list(self._shapes)}

        # update based on the latest spatial shapes
        self.has_dynamic_shapes = self._check_dynamic_shapes()
        if self._placeholders_dict is not None:
            self._update_placeholders_dict(n_samples=self.n_samples)
Developer ID: fepegar, Project: NiftyNet, Lines of code: 56, Source: image_window.py


Example 5: __init__

  def __init__(self, dataset, output_types, output_shapes=None):
    """Creates a new dataset with the given output types and shapes.

    The given `dataset` must have a structure that is convertible:
    * `dataset.output_types` must be the same as `output_types` modulo nesting.
    * Each shape in `dataset.output_shapes` must be compatible with each shape
      in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
        If omitted, the shapes will be inherited from `dataset`.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not compatible
        with the structure of `dataset`.
    """
    super(_RestructuredDataset, self).__init__()
    self._dataset = dataset

    # Validate that the types are compatible.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    flat_original_types = nest.flatten(dataset.output_types)
    flat_new_types = nest.flatten(output_types)
    if flat_original_types != flat_new_types:
      raise ValueError(
          "Dataset with output types %r cannot be restructured to have output "
          "types %r" % (dataset.output_types, output_types))

    self._output_types = output_types

    if output_shapes is None:
      # Inherit shapes from the original `dataset`.
      self._output_shapes = nest.pack_sequence_as(output_types,
                                                  nest.flatten(
                                                      dataset.output_shapes))
    else:
      # Validate that the shapes are compatible.
      nest.assert_same_structure(output_types, output_shapes)
      flat_original_shapes = nest.flatten(dataset.output_shapes)
      flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)

      for original_shape, new_shape in zip(flat_original_shapes,
                                           flat_new_shapes):
        if not original_shape.is_compatible_with(new_shape):
          raise ValueError(
              "Dataset with output shapes %r cannot be restructured to have "
              "incompatible output shapes %r" % (dataset.output_shapes,
                                                 output_shapes))
      self._output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
Developer ID: Crazyonxh, Project: tensorflow, Lines of code: 55, Source: batching.py
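The shape validation loop in this constructor relies on a per-pair compatibility test. As a rough illustration of what such a test involves, the sketch below models a shape as a tuple of dimensions, with `None` standing for an unknown dimension and a shape of `None` standing for unknown rank. The functions `is_compatible` and `check_restructure` are hypothetical helpers for this example. They are meant to mirror the general idea of `is_compatible_with`, not to reproduce TensorFlow's implementation.

```python
def is_compatible(shape_a, shape_b):
    """Return True if two partially known shapes could describe one tensor.

    Assumption for this sketch: a shape is a tuple of ints/None, or None
    when even the rank is unknown.
    """
    if shape_a is None or shape_b is None:
        return True  # unknown rank is compatible with anything
    if len(shape_a) != len(shape_b):
        return False
    return all(a is None or b is None or a == b
               for a, b in zip(shape_a, shape_b))


def check_restructure(flat_original_shapes, flat_new_shapes):
    """Raise ValueError when any pair of flat shapes is incompatible."""
    for original, new in zip(flat_original_shapes, flat_new_shapes):
        if not is_compatible(original, new):
            raise ValueError(
                "shape %r cannot be restructured to incompatible shape %r"
                % (original, new))


# An unknown batch dimension can be narrowed to a concrete size...
check_restructure([(None, 3), None], [(32, 3), (5,)])

# ...but a known dimension cannot be changed.
try:
    check_restructure([(32, 3)], [(32, 4)])
    rejected = False
except ValueError:
    rejected = True
```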


Example 6: __init__

 def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
   """Initialize `PrependFromQueueAndPaddedBatchDataset`."""
   super(_PrependFromQueueAndPaddedBatchDataset, self).__init__()
   if sparse.any_sparse(input_dataset.output_classes):
     raise TypeError(
         "Batching of padded sparse tensors is not currently supported")
   self._input_dataset = input_dataset
   self._batch_size = ops.convert_to_tensor(
       batch_size, dtype=dtypes.int64, name="batch_size")
   # pylint: disable=protected-access
   if padded_shapes is None:
     self._padded_shapes = nest.map_structure(
         dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes)
   else:
     self._padded_shapes = nest.map_structure_up_to(
         input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor,
         padded_shapes)
   padding_values = (
       padding_values if padding_values is not None else
       dataset_ops._default_padding(input_dataset))
   self._padding_values = nest.map_structure_up_to(
       input_dataset.output_shapes, dataset_ops._padding_value_to_tensor,
       padding_values, input_dataset.output_types)
Developer ID: AndrewTwinz, Project: tensorflow, Lines of code: 23, Source: tensor_queue_dataset.py
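The last call in this constructor passes two structures after the function, so the function receives one leaf from each of them per position. The sketch below shows that multi-input form in plain Python. The local `map_structure_up_to` is a simplified stand-in (tuples and dicts nest, lists are leaves), and `cast_padding_value` with its string dtype names is a hypothetical substitute for `dataset_ops._padding_value_to_tensor`, used only to make the example runnable.

```python
def map_structure_up_to(shallow, func, *inputs):
    """Multi-input map that stops at the leaves of shallow (sketch)."""
    if isinstance(shallow, dict):
        return {k: map_structure_up_to(shallow[k], func,
                                       *[inp[k] for inp in inputs])
                for k in shallow}
    if isinstance(shallow, tuple):
        return tuple(
            map_structure_up_to(sub, func, *[inp[i] for inp in inputs])
            for i, sub in enumerate(shallow))
    return func(*inputs)


def cast_padding_value(value, dtype_name):
    """Hypothetical stand-in: coerce a padding value to the named dtype."""
    caster = {"int64": int, "float32": float, "string": str}[dtype_name]
    return caster(value)


# Shapes are written as lists so that they count as leaves in this sketch.
output_shapes = ([None], {"scores": [None, 3], "name": []})
output_types = ("int64", {"scores": "float32", "name": "string"})
padding_values = (0, {"scores": -1, "name": ""})

padded = map_structure_up_to(
    output_shapes, cast_padding_value, padding_values, output_types)
```

Each call to `cast_padding_value` sees the padding value and the dtype that sit at the same position as one shape in `output_shapes`.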


Example 7: __init__

 def __init__(self, variant_tensor, output_shapes, output_types,
              output_classes):
   # TODO(b/110122868): Consolidate the structure validation logic with the
   # similar logic in `Iterator.from_structure()` and
   # `Dataset.from_generator()`.
   output_types = nest.map_structure(dtypes.as_dtype, output_types)
   output_shapes = nest.map_structure_up_to(
       output_types, tensor_shape.as_shape, output_shapes)
   nest.assert_same_structure(output_types, output_shapes)
   nest.assert_same_structure(output_types, output_classes)
   self._variant_tensor = variant_tensor
   self._output_shapes = output_shapes
   self._output_types = output_types
   self._output_classes = output_classes
Developer ID: AnishShah, Project: tensorflow, Lines of code: 14, Source: optional_ops.py


Example 8: from_data_reader_properties

    def from_data_reader_properties(cls,
                                    source_names,
                                    image_shapes,
                                    image_dtypes,
                                    window_sizes=None,
                                    allow_dynamic=False):
        """
        Create a window instance with input data properties.
        Each property is grouped into a dict, with pairs of
        image_name: data_value. Some input images are
        concatenated data arrays from multiple data sources.
        Example of input::

            source_names={
                'image': (u'modality1', u'modality2'),
                'label': (u'modality3',)},
            image_shapes={
                'image': (192, 160, 192, 1, 2),
                'label': (192, 160, 192, 1, 1)},
            image_dtypes={
                'image': tf.float32,
                'label': tf.float32},
            window_sizes={
                'image': (10, 10, 2),
                'label': (10, 10, 2)}

        the ``window_sizes`` can also be::

            window_sizes={
                'modality1': (10, 10, 2),
                'modality3': (10, 10, 2)}

        or using a nested dictionary with 'spatial_window_size' (deprecating)::

            window_sizes={
                'modality1': {'spatial_window_size': (10, 10, 2)},
                'modality2': {'spatial_window_size': (10, 10, 2)},
                'modality3': {'spatial_window_size': (5, 5, 1)}}

        see ``niftynet.io.ImageReader`` for more details.

        :param source_names: input image names
        :param image_shapes: tuple of image window shapes
        :param image_dtypes: tuple of image window data types
        :param window_sizes: window sizes for the image
        :param allow_dynamic: if True, negative or zero values in
            window_sizes indicate dynamic window sizes. Otherwise the dynamic
            sizes will be fixed as the image shapes; this assumes the same
            image size across the dataset.
        :return: an ImageWindow instance
        """
        try:
            image_shapes = nest.map_structure_up_to(
                image_dtypes, tuple, image_shapes)
        except KeyError:
            tf.logging.fatal('window_sizes wrong format %s', window_sizes)
            raise
        # create ImageWindow instance
        window_instance = cls(shapes=image_shapes, dtypes=image_dtypes)

        if not window_sizes:
            # image window sizes not specified, defaulting to image sizes.
            return window_instance

        window_instance.set_spatial_shape(window_sizes, source_names)
        if not allow_dynamic:
            full_shape = window_instance.match_image_shapes(image_shapes)
            window_instance.set_spatial_shape(full_shape)
        return window_instance
Developer ID: fepegar, Project: NiftyNet, Lines of code: 69, Source: image_window.py


Example 9: from_structure

  def from_structure(output_types,
                     output_shapes=None,
                     shared_name=None,
                     output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned by
    `Iterator.make_initializer(dataset)`.

    The following is an example

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
      # Initialize the iterator to `dataset_range`
      sess.run(range_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break

      # Initialize the iterator to `dataset_evens`
      sess.run(evens_initializer)
      while True:
        try:
          pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
          break
    ```

    Args:
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      shared_name: (Optional.) If non-empty, this iterator will be shared under
        the given name across multiple sessions that share the same devices
        (e.g. when using a remote server).
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.

    Raises:
      TypeError: If the structures of `output_shapes` and `output_types` are
        not the same.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    if shared_name is None:
      shared_name = ""
    iterator_resource = gen_dataset_ops.iterator(
        container="",
        shared_name=shared_name,
        output_types=nest.flatten(
            sparse.as_dense_types(output_types, output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(output_shapes, output_classes)))
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
Developer ID: modkzs, Project: tensorflow, Lines of code: 90, Source: iterator_ops.py


Example 10: from_string_handle

  def from_string_handle(string_handle,
                         output_types,
                         output_shapes=None,
                         output_classes=None):
    """Creates a new, uninitialized `Iterator` based on the given handle.

    This method allows you to define a "feedable" iterator where you can choose
    between concrete iterators by feeding a value in a `tf.Session.run` call.
    In that case, `string_handle` would be a `tf.placeholder`, and you would
    feed it with the value of `tf.data.Iterator.string_handle` in each step.

    For example, if you had two iterators that marked the current position in
    a training dataset and a test dataset, you could choose which to use in
    each step as follows:

    ```python
    train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    train_iterator_handle = sess.run(train_iterator.string_handle())

    test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
    test_iterator_handle = sess.run(test_iterator.string_handle())

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, train_iterator.output_types)

    next_element = iterator.get_next()
    loss = f(next_element)

    train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
    test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
    ```

    Args:
      string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
        to a handle produced by the `Iterator.string_handle()` method.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element of this dataset.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
        corresponding to each component of an element of this dataset. If
        omitted, each component will have an unconstrained shape.
      output_classes: (Optional.) A nested structure of Python `type` objects
        corresponding to each component of an element of this iterator. If
        omitted, each component is assumed to be of type `tf.Tensor`.

    Returns:
      An `Iterator`.
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)
    if output_classes is None:
      output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    output_structure = structure_lib.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
    # pylint: disable=protected-access
    if compat.forward_compatible(2018, 8, 3):
      if _device_stack_is_empty():
        with ops.device("/cpu:0"):
          iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
              string_handle,
              output_types=output_structure._flat_types,
              output_shapes=output_structure._flat_shapes)
      else:
        iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
            string_handle,
            output_types=output_structure._flat_types,
            output_shapes=output_structure._flat_shapes)
    else:
      iterator_resource = gen_dataset_ops.iterator_from_string_handle(
          string_handle,
          output_types=output_structure._flat_types,
          output_shapes=output_structure._flat_shapes)
    # pylint: enable=protected-access
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
Developer ID: perfmjs, Project: tensorflow, Lines of code: 82, Source: iterator_ops.py


Example 11: from_generator

  def from_generator(generator, output_types, output_shapes=None):
    """Creates a `Dataset` whose elements are generated by `generator`.

    The `generator` argument must be a callable object that returns
    an object that supports the `iter()` protocol (e.g. a generator function).
    The elements generated by `generator` must be compatible with the given
    `output_types` and (optional) `output_shapes` arguments.

    For example:

    ```python
    import itertools

    def gen():
      for i in itertools.count(1):
        yield (i, [1] * i)

    ds = Dataset.from_generator(
        gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
    value = ds.make_one_shot_iterator().get_next()

    sess.run(value)  # (1, array([1]))
    sess.run(value)  # (2, array([1, 1]))
    ```

    Args:
      generator: A callable object that takes no arguments and returns an
        object that supports the `iter()` protocol.
      output_types: A nested structure of `tf.DType` objects corresponding to
        each component of an element yielded by `generator`.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape`
        objects corresponding to each component of an element yielded by
        `generator`.

    Returns:
      A `Dataset`.
    """
    if not callable(generator):
      raise TypeError("`generator` must be callable.")
    if output_shapes is None:
      output_shapes = nest.map_structure(
          lambda _: tensor_shape.TensorShape(None), output_types)
    else:
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)

    flattened_types = nest.flatten(output_types)
    flattened_shapes = nest.flatten(output_shapes)

    generator_state = dataset_ops.Dataset._GeneratorState(generator)

    def get_iterator_id_map_fn(unused_dummy):
      """Creates a unique `iterator_id` for each pass over the dataset.

      The "iterator_id" disambiguates between multiple concurrently
      existing iterators.

      Args:
        unused_dummy: Ignored value.

      Returns:
        A `tf.int64` tensor whose value uniquely identifies an iterator in
        `generator_state`.
      """
      return script_ops.py_func(
          generator_state.get_next_id, [], dtypes.int64, stateful=True)

    def generator_map_fn(iterator_id_t):
      """Generates the next element from iterator with ID `iterator_id_t`.

      We map this function across an infinite repetition of the
      `iterator_id_t`, and raise `StopIteration` to terminate the iteration.

      Args:
        iterator_id_t: A `tf.int64` tensor whose value uniquely identifies
          the iterator in `generator_state` from which to generate an element.

      Returns:
        A nested structure of tensors representing an element from the iterator.
      """

      def generator_py_func(iterator_id):
        """A `py_func` that will be called to invoke the iterator."""
        try:
          values = next(generator_state.get_iterator(iterator_id))
        except StopIteration:
          generator_state.iterator_completed(iterator_id)
          raise StopIteration("Iteration finished.")

        # Use the same _convert function from the py_func() implementation to
        # convert the returned values to arrays early, so that we can inspect
        # their values.
        # pylint: disable=protected-access
        ret_arrays = [
            script_ops.FuncRegistry._convert(ret, dtype=dtype.as_numpy_dtype)
            for ret, dtype in zip(nest.flatten_up_to(output_types, values),
                                  flattened_types)
        ]
        # pylint: enable=protected-access

#.........part of the code is omitted here.........
Developer ID: Mazecreator, Project: tensorflow, Lines of code: 101, Source: dataset_ops.py
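Inside `generator_py_func`, each yielded element is flattened only as far as `output_types` is nested, so a Python list such as `[1] * i` stays a single component. The sketch below reproduces that step in plain Python with the generator from the docstring. The local `flatten_up_to` is a simplified stand-in that assumes tuples and dicts are nesting and lists are leaves, and the dtype names are plain strings chosen for the illustration.

```python
import itertools


def flatten_up_to(shallow, structure):
    """Flatten structure only down to the leaves of shallow (sketch)."""
    if isinstance(shallow, dict):
        flat = []
        for key in sorted(shallow):
            flat.extend(flatten_up_to(shallow[key], structure[key]))
        return flat
    if isinstance(shallow, tuple):
        flat = []
        for sub, value in zip(shallow, structure):
            flat.extend(flatten_up_to(sub, value))
        return flat
    # shallow is a leaf: keep the matching part of structure whole.
    return [structure]


def gen():
    for i in itertools.count(1):
        yield (i, [1] * i)


output_types = ("int64", "int64")  # one dtype name per component

# Take the first two elements of the infinite generator and flatten each
# one into its two components.
first_two = [flatten_up_to(output_types, values)
             for values in itertools.islice(gen(), 2)]
```

Each flattened element has exactly as many entries as `output_types` has leaves, which is what allows the original code to `zip` the components with `flattened_types`.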


Example 12: __init__

  def __init__(self,
               dataset,
               output_types,
               output_shapes=None,
               output_classes=None,
               allow_unsafe_cast=False):
    """Creates a new dataset with the given output types and shapes.

    The given `dataset` must have a structure that is convertible:
    * `dataset.output_types` must be the same as `output_types` modulo nesting.
    * Each shape in `dataset.output_shapes` must be compatible with each shape
      in `output_shapes` (if given).

    Note: This helper permits "unsafe casts" for shapes, equivalent to using
    `tf.Tensor.set_shape()` where domain-specific knowledge is available.

    Args:
      dataset: A `Dataset` object.
      output_types: A nested structure of `tf.DType` objects.
      output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
        If omitted, the shapes will be inherited from `dataset`.
      output_classes: (Optional.) A nested structure of class types.
        If omitted, the class types will be inherited from `dataset`.
      allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
        reported output types and shapes of the restructured dataset, e.g. to
        switch a sparse tensor represented as `tf.variant` to its user-visible
        type and shape.

    Raises:
      ValueError: If either `output_types` or `output_shapes` is not compatible
        with the structure of `dataset`.
    """
    self._input_dataset = dataset

    input_types = dataset_ops.get_legacy_output_types(dataset)
    if not allow_unsafe_cast:
      # Validate that the types are compatible.
      output_types = nest.map_structure(dtypes.as_dtype, output_types)
      flat_original_types = nest.flatten(input_types)
      flat_new_types = nest.flatten(output_types)
      if flat_original_types != flat_new_types:
        raise ValueError(
            "Dataset with output types %r cannot be restructured to have "
            "output types %r" %
            (dataset_ops.get_legacy_output_types(dataset), output_types))

    input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    if output_shapes is None:
      # Inherit shapes from the original `dataset`.
      output_shapes = nest.pack_sequence_as(
          output_types, nest.flatten(input_shapes))
    else:
      if not allow_unsafe_cast:
        # Validate that the shapes are compatible.
        nest.assert_same_structure(output_types, output_shapes)
        flat_original_shapes = nest.flatten(input_shapes)
        flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)

        for original_shape, new_shape in zip(flat_original_shapes,
                                             flat_new_shapes):
          if not original_shape.is_compatible_with(new_shape):
            raise ValueError(
                "Dataset with output shapes %r cannot be restructured to have "
                "incompatible output shapes %r" % (input_shapes,
                                                   output_shapes))
      output_shapes = nest.map_structure_up_to(
          output_types, tensor_shape.as_shape, output_shapes)

    input_classes = dataset_ops.get_legacy_output_classes(dataset)
    if output_classes is None:
      # Inherit class types from the original `dataset`.
      output_classes = nest.pack_sequence_as(
          output_types, nest.flatten(input_classes))

    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
Developer ID: kylin9872, Project: tensorflow, Lines of code: 78, Source: batching.py
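When `output_shapes` is omitted, the constructor inherits shapes by flattening the input shapes and repacking them into the nesting of `output_types`. The sketch below illustrates that flatten-then-repack idea in plain Python. Both `flatten` and `pack_sequence_as` are simplified stand-ins written for this example. They assume tuples and dicts are nesting, lists are leaves, and dict entries are visited in sorted key order. The shapes and dtype names are sample data.

```python
def flatten(structure):
    """Return the leaves of structure in a deterministic order (sketch)."""
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure)
                for leaf in flatten(structure[key])]
    if isinstance(structure, tuple):
        return [leaf for sub in structure for leaf in flatten(sub)]
    return [structure]


def pack_sequence_as(structure, flat_sequence):
    """Arrange flat_sequence into the nesting of structure (sketch)."""
    leaves = iter(flat_sequence)

    def pack(sub):
        if isinstance(sub, dict):
            return {key: pack(sub[key]) for key in sorted(sub)}
        if isinstance(sub, tuple):
            return tuple(pack(item) for item in sub)
        return next(leaves)  # consume one leaf per leaf position

    return pack(structure)


# The input dataset yields a flat 3-tuple; shapes are lists (leaves here).
input_shapes = ([], [None], [3])
# The restructured dataset groups the same three components differently.
output_types = {"id": "int64", "pair": ("float32", "float32")}

output_shapes = pack_sequence_as(output_types, flatten(input_shapes))
```

The repacking only works because both structures have the same number of leaves, which is the property the type check on the flattened lists protects in the original code.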



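Note: The tensorflow.python.data.util.nest.map_structure_up_to function examples in this article were compiled by 纯净天空 from source code and documentation hosting platforms such as GitHub and MSDocs. The code snippets were selected from open-source projects contributed by their respective developers, and copyright of the source code remains with the original authors. For distribution and use, refer to the license of the corresponding project. Do not reproduce without permission.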

