• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python dataset_ops.flat_structure函数代码示例

原作者: 见各示例下方的开发者ID 来自: Github 开源项目(见各示例下方的项目名称与代码来源) 收藏 邀请

本文整理汇总了Python中tensorflow.python.data.ops.dataset_ops.flat_structure函数的典型用法代码示例。如果您正苦于以下问题:Python flat_structure函数的具体用法?Python flat_structure怎么用?Python flat_structure使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了flat_structure函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: __init__

 def __init__(self, input_dataset):
   """Wraps `input_dataset` in a prefetch op followed by a model op.

   Args:
     input_dataset: the dataset whose variant tensor is wrapped; it is also
       handed to the superclass constructor.
   """
   self._input_dataset = input_dataset
   # A prefetch with a one-element buffer sits between the input and the
   # model op.
   prefetched = gen_dataset_ops.prefetch_dataset(
       input_dataset._variant_tensor,  # pylint: disable=protected-access
       buffer_size=1,
       **dataset_ops.flat_structure(self))
   modeled = gen_dataset_ops.model_dataset(
       prefetched, **dataset_ops.flat_structure(self))
   super(_TestDataset, self).__init__(input_dataset, modeled)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:9,代码来源:input_ops_test.py


示例2: __init__

  def __init__(self, per_device_dataset, incarnation_id):
    """Builds a copy of `per_device_dataset` bound to a new incarnation id.

    The init/next/finalize functions and their captured arguments are reused
    from `per_device_dataset`. The one exception is the captured `next_func`
    argument at `per_device_dataset._incarnation_id_index`, which is replaced
    with `incarnation_id` before the generator dataset op is created.

    Args:
      per_device_dataset: the per-device generator dataset to reincarnate.
      incarnation_id: value substituted for the captured incarnation id.
    """
    # pylint: disable=protected-access
    self._structure = per_device_dataset._structure

    self._init_func = per_device_dataset._init_func
    self._init_captured_args = self._init_func.captured_inputs

    self._next_func = per_device_dataset._next_func
    # Copy before editing. Assigning into the list read from
    # `per_device_dataset` would otherwise silently rewrite that object's
    # captured arguments as a side effect of constructing this one.
    self._next_captured_args = list(per_device_dataset._next_captured_args)
    # The captured arguments to the next_func are string_handle, incarnation_id.
    # We update the incarnation id to the new one.
    self._next_captured_args[
        per_device_dataset._incarnation_id_index] = incarnation_id

    self._finalize_func = per_device_dataset._finalize_func
    self._finalize_captured_args = per_device_dataset._finalize_captured_args

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **dataset_ops.flat_structure(self))
    super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:26,代码来源:multi_device_iterator_ops.py


示例3: _as_variant_tensor

 def _as_variant_tensor(self):
   """Returns a map_dataset variant tensor that applies `self._map_func`.

   The input dataset's variant tensor is built first. It is then passed to
   the op together with the map function's captured inputs.
   """
   # pylint: disable=protected-access
   upstream = self._input_dataset._as_variant_tensor()
   fn = self._map_func
   return gen_dataset_ops.map_dataset(
       upstream,
       fn.captured_inputs,
       f=fn,
       **dataset_ops.flat_structure(self))
开发者ID:bunbutter,项目名称:tensorflow,代码行数:7,代码来源:grouping.py


示例4: _as_variant_tensor

 def _as_variant_tensor(self):
   """Returns the variant tensor of a directed-interleave op.

   The selector input's variant tensor is passed first, followed by a list of
   the data inputs' variant tensors in their original order.
   """
   # pylint: disable=protected-access
   interleave = (
       gen_experimental_dataset_ops.experimental_directed_interleave_dataset)
   selector_variant = self._selector_input._variant_tensor
   data_variants = [ds._variant_tensor for ds in self._data_inputs]
   return interleave(selector_variant, data_variants,
                     **dataset_ops.flat_structure(self))
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:interleave_ops.py


示例5: __init__

  def __init__(self, datasets, num_experiments=10):
    """Chooses the fastest of some input datasets.

    Given input datasets, produces elements as quickly as the fastest of the
    inputs. Note that this dataset assumes that input datasets have the same
    elements in the same order, though this is not enforced besides checking
    that the input datasets have compatible output types, output shapes, and
    cardinality at runtime. The resulting dataset produces elements that are
    identical to the input elements, and in the same order.

    Note that the time to first iteration is longer when this dataset is used
    due to the overhead of dynamically picking the faster dataset. Namely,
    for the first num_experiments iterations, this dataset will pull from all
    of its inputs simultaneously in order to determine which input is the
    fastest. For all subsequent iterations, that input will be used.

    The constructed dataset has the same elements as the inputs. Its element
    structure is taken from the first input dataset.

    Args:
      datasets: A non-empty iterable of `Dataset`s that all have the same
        elements in the same order.
      num_experiments: The number of experiments to run before deciding which
        dataset is fastest. In each "experiment" iteration, the dataset will
        call from all its inputs simultaneously, and update its knowledge of
        which input is the fastest.

    Raises:
      ValueError: If `datasets` is empty.
    """
    self._datasets = list(datasets)
    if not self._datasets:
      # Fail with a clear message instead of an IndexError on `[0]` below.
      raise ValueError("`datasets` must contain at least one dataset.")
    # pylint: disable=protected-access
    # All inputs are assumed to share one element structure; use the first.
    self._structure = self._datasets[0]._element_structure
    variant_tensor = (
        gen_experimental_dataset_ops.experimental_choose_fastest_dataset(
            [dataset._variant_tensor for dataset in self._datasets],
            num_experiments=num_experiments,
            **dataset_ops.flat_structure(self)))
    # pylint: enable=protected-access
    super(_ChooseFastestDataset, self).__init__(variant_tensor)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:35,代码来源:optimization.py


示例6: __init__

  def __init__(self, input_dataset, num_workers):
    """Creates a dataset that applies the experimental rebatch op.

    In the static output shapes, the leading dimension of every component is
    divided by `num_workers`. Unknown leading dimensions stay unknown.

    Args:
      input_dataset: the dataset to rebatch.
      num_workers: divisor applied to each component's leading dimension.

    Raises:
      ValueError: if a component shape has no dimensions.
      errors.InvalidArgumentError: if a statically known, non-zero leading
        dimension is not divisible by `num_workers`.
    """
    self._input_dataset = input_dataset

    def recalculate_output_shapes(output_shapes):
      """Recalculates the output_shapes after dividing it by num_workers."""
      if len(output_shapes) < 1:
        raise ValueError("Input shape should have at least one dimension.")
      first_dim = tensor_shape.dimension_value(output_shapes[0])
      if first_dim and first_dim % num_workers != 0:
        raise errors.InvalidArgumentError(
            None, None,
            "First dim of input shape: %d is not divisible by num_workers: %d" %
            (first_dim, num_workers))
      output_dims = list(output_shapes.dims)
      # Use the plain int (or None) value. Dividing the raw dim fails when it
      # is None, i.e. an unknown leading dimension.
      output_dims[0] = None if first_dim is None else first_dim // num_workers
      return tensor_shape.TensorShape(output_dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
    output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, output_shapes, input_classes)
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **dataset_ops.flat_structure(self))
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
开发者ID:aritratony,项目名称:tensorflow,代码行数:29,代码来源:distribute.py


示例7: _as_variant_tensor

 def _as_variant_tensor(self):
   """Returns a slide_dataset variant tensor over the input dataset.

   Window size, shift and stride are read from the corresponding private
   attributes of this object.
   """
   upstream = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
   return gen_dataset_ops.slide_dataset(
       upstream,
       window_size=self._window_size,
       window_shift=self._window_shift,
       window_stride=self._window_stride,
       **dataset_ops.flat_structure(self))
开发者ID:ChristinaEricka,项目名称:tensorflow,代码行数:7,代码来源:sliding.py


示例8: materialize

  def materialize(self, shared_name=None, container=None):
    """Creates a MaterializedIndexedDataset backed by a new resource handle.

    IndexedDatasets are only materialized when this method is called.

    Args:
      shared_name: optional shared name for the resource; `None` becomes "".
      container: optional container for the resource; `None` becomes "".

    Returns:
      A MaterializedIndexedDataset holding the resource handle, the
      materialize op, and this dataset's output classes, types and shapes.
    """
    container = "" if container is None else container
    shared_name = "" if shared_name is None else shared_name
    handle = ged_ops.experimental_materialized_index_dataset_handle(
        container=container,
        shared_name=shared_name,
        **dataset_ops.flat_structure(self))

    # Colocate the materialize op with the resource handle it targets.
    with ops.colocate_with(handle):
      materialize_op = ged_ops.experimental_indexed_dataset_materialize(
          self._as_variant_tensor(), handle)
    return MaterializedIndexedDataset(handle, materialize_op,
                                      self.output_classes, self.output_types,
                                      self.output_shapes)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:29,代码来源:indexed_dataset_ops.py


示例9: __init__

  def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
               drop_remainder, use_legacy_function=False):
    """Creates a dataset that fuses `map_func` with batching.

    See `Dataset.map()` for details on how `map_func` is applied.

    Args:
      input_dataset: the dataset whose elements are mapped and batched.
      map_func: function wrapped in a `StructuredFunctionWrapper`.
      batch_size: converted to an int64 tensor named "batch_size".
      num_parallel_calls: converted to an int64 tensor.
      drop_remainder: converted to a bool tensor. Only when its static value
        is truthy does the output structure get a static batch dimension.
      use_legacy_function: forwarded to `StructuredFunctionWrapper`.
    """
    self._input_dataset = input_dataset

    self._map_func = dataset_ops.StructuredFunctionWrapper(
        map_func,
        "tf.data.experimental.map_and_batch()",
        dataset=input_dataset,
        use_legacy_function=use_legacy_function)
    self._batch_size_t = ops.convert_to_tensor(
        batch_size, dtype=dtypes.int64, name="batch_size")
    self._num_parallel_calls_t = ops.convert_to_tensor(
        num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
    self._drop_remainder_t = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")

    # NOTE(mrry): the static drop_remainder value may be `None` (unknown
    # statically) or `False` (explicitly retaining the remainder). In both
    # cases the batch dimension is left unknown.
    if tensor_util.constant_value(self._drop_remainder_t):
      static_batch_size = tensor_util.constant_value(self._batch_size_t)
    else:
      static_batch_size = None
    self._structure = self._map_func.output_structure._batch(  # pylint: disable=protected-access
        static_batch_size)
    variant_tensor = ged_ops.experimental_map_and_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._map_func.function.captured_inputs,
        f=self._map_func.function,
        batch_size=self._batch_size_t,
        num_parallel_calls=self._num_parallel_calls_t,
        drop_remainder=self._drop_remainder_t,
        preserve_cardinality=True,
        **dataset_ops.flat_structure(self))
    super(_MapAndBatchDataset, self).__init__(input_dataset, variant_tensor)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:35,代码来源:batching.py


示例10: _as_variant_tensor

 def _as_variant_tensor(self):
   """Returns a set_stats_aggregator_dataset variant tensor over the input.

   The aggregator's resource handle, tag and prefix are forwarded to the op.
   """
   # pylint: disable=protected-access
   upstream = self._input_dataset._as_variant_tensor()
   aggregator_handle = self._stats_aggregator._resource
   return gen_dataset_ops.set_stats_aggregator_dataset(
       upstream, aggregator_handle, self._tag, self._prefix,
       **dataset_ops.flat_structure(self))
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:7,代码来源:stats_ops.py


示例11: _as_variant_tensor

 def _as_variant_tensor(self):
   """Returns an experimental sliding-window variant tensor over the input.

   The output type/shape keyword arguments are derived from
   `self._output_structure`.
   """
   upstream = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
   return ged_ops.experimental_sliding_window_dataset(
       upstream,
       window_size=self._window_size,
       window_shift=self._window_shift,
       window_stride=self._window_stride,
       **dataset_ops.flat_structure(structure=self._output_structure))
开发者ID:aeverall,项目名称:tensorflow,代码行数:7,代码来源:sliding.py


示例12: __init__

  def __init__(self, input_dataset, features, num_parallel_calls):
    """Creates a dataset that applies the experimental parse-example op.

    Args:
      input_dataset: a dataset whose element structure must be compatible
        with a string vector (`TensorStructure(dtypes.string, [None])`).
      features: feature specification, normalized through
        `parsing_ops._prepend_none_dimension`.
      num_parallel_calls: passed through to the parse op.

    Raises:
      TypeError: if `input_dataset`'s element structure is not compatible
        with a vector of strings.
    """
    self._input_dataset = input_dataset
    string_vector = structure.TensorStructure(dtypes.string, [None])
    if not input_dataset._element_structure.is_compatible_with(string_vector):  # pylint: disable=protected-access
      raise TypeError("Input dataset should be a dataset of vectors of strings")
    self._num_parallel_calls = num_parallel_calls
    # pylint: disable=protected-access
    self._features = parsing_ops._prepend_none_dimension(features)
    # sparse_keys and dense_keys come back sorted here.
    (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
     dense_shapes) = parsing_ops._features_to_raw_params(
         self._features, [
             parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
             parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
         ])
    # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
    (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
     dense_shape_as_shape) = parsing_ops._process_raw_parameters(
         None, dense_defaults, sparse_keys, sparse_types, dense_keys,
         dense_types, dense_shapes)
    # pylint: enable=protected-access
    self._sparse_keys = sparse_keys
    self._sparse_types = sparse_types
    self._dense_keys = dense_keys
    self._dense_defaults = dense_defaults_vec
    self._dense_shapes = dense_shapes
    self._dense_types = dense_types

    # Every output component's shape is prefixed with the input element shape.
    prefix_shape = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    dense_output_shapes = [prefix_shape.concatenate(feature_shape)
                           for feature_shape in dense_shape_as_shape]
    sparse_output_shapes = [prefix_shape.concatenate([None])
                            for _ in sparse_keys]

    # Dense keys come first, then sparse keys, in every mapping below.
    all_keys = self._dense_keys + self._sparse_keys
    output_shapes = dict(zip(all_keys,
                             dense_output_shapes + sparse_output_shapes))
    output_types = dict(zip(all_keys, self._dense_types + self._sparse_types))
    output_classes = dict(zip(
        all_keys,
        [ops.Tensor] * len(self._dense_defaults) +
        [sparse_tensor.SparseTensor] * len(self._sparse_keys)))
    self._structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)

    parse_op = gen_experimental_dataset_ops.experimental_parse_example_dataset
    variant_tensor = parse_op(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._num_parallel_calls,
        self._dense_defaults,
        self._sparse_keys,
        self._dense_keys,
        self._sparse_types,
        self._dense_shapes,
        **dataset_ops.flat_structure(self))
    super(_ParseExampleDataset, self).__init__(input_dataset, variant_tensor)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:59,代码来源:parsing_ops.py


示例13: __init__

 def __init__(self, input_dataset, sleep_microseconds):
   """Wraps `input_dataset` in an experimental sleep dataset op.

   Args:
     input_dataset: the dataset to wrap.
     sleep_microseconds: passed to the sleep op as its second argument.
   """
   self._input_dataset = input_dataset
   self._sleep_microseconds = sleep_microseconds
   sleep_op = gen_experimental_dataset_ops.experimental_sleep_dataset
   variant_tensor = sleep_op(
       self._input_dataset._variant_tensor,  # pylint: disable=protected-access
       self._sleep_microseconds,
       **dataset_ops.flat_structure(self))
   super(_SleepDataset, self).__init__(input_dataset, variant_tensor)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:8,代码来源:sleep.py


示例14: _as_variant_tensor

 def _as_variant_tensor(self):
   """Returns a scan_dataset variant tensor over the input dataset.

   Sparse tensors in the initial state are serialized and the state is
   flattened. It is then passed to the op alongside the scan function's
   captured inputs.
   """
   upstream = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
   flat_initial_state = nest.flatten(
       sparse.serialize_sparse_tensors(self._initial_state))
   return gen_dataset_ops.scan_dataset(
       upstream,
       flat_initial_state,
       self._scan_func.captured_inputs,
       f=self._scan_func,
       **dataset_ops.flat_structure(self))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:8,代码来源:scan_ops.py


示例15: __init__

 def __init__(self, input_dataset, thread_pool):
   """Wraps `input_dataset` in an experimental thread-pool dataset op.

   Args:
     input_dataset: the dataset to wrap.
     thread_pool: object whose `_resource` handle is passed to the op.
   """
   self._input_dataset = input_dataset
   self._thread_pool = thread_pool
   # pylint: disable=protected-access
   upstream = self._input_dataset._variant_tensor
   pool_handle = self._thread_pool._resource
   variant_tensor = ged_ops.experimental_thread_pool_dataset(
       upstream, pool_handle, **dataset_ops.flat_structure(self))
   super(_ThreadPoolDataset, self).__init__(input_dataset, variant_tensor)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:8,代码来源:threadpool.py


示例16: _as_variant_tensor

 def _as_variant_tensor(self):
   """Returns an experimental_map_dataset variant tensor over the input.

   The wrapped map function, its captured inputs and this object's
   `_use_inter_op_parallelism` flag are forwarded to the op.
   """
   # pylint: disable=protected-access
   upstream = self._input_dataset._as_variant_tensor()
   wrapped_fn = self._map_func.function
   return ged_ops.experimental_map_dataset(
       upstream,
       wrapped_fn.captured_inputs,
       f=wrapped_fn,
       use_inter_op_parallelism=self._use_inter_op_parallelism,
       **dataset_ops.flat_structure(self))
开发者ID:aeverall,项目名称:tensorflow,代码行数:8,代码来源:prefetching_ops.py


示例17: __init__

 def __init__(self, input_dataset):
   """See `Dataset.ignore_errors()` for details.

   Wraps `input_dataset` in an experimental ignore-errors dataset op.
   """
   self._input_dataset = input_dataset
   ignore_errors_op = (
       gen_experimental_dataset_ops.experimental_ignore_errors_dataset)
   variant_tensor = ignore_errors_op(
       self._input_dataset._variant_tensor,  # pylint: disable=protected-access
       **dataset_ops.flat_structure(self))
   super(_IgnoreErrorsDataset, self).__init__(input_dataset, variant_tensor)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:8,代码来源:error_ops.py


示例18: __init__

  def __init__(self, input_dataset, path):
    """Wraps `input_dataset` in a snapshot dataset op parameterized by `path`.

    Args:
      input_dataset: the dataset to wrap.
      path: converted to a string tensor named "path" and passed to the op.
    """
    self._input_dataset = input_dataset
    self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")

    upstream = self._input_dataset._variant_tensor  # pylint: disable=protected-access
    variant_tensor = ged_ops.snapshot_dataset(
        upstream, path=self._path, **dataset_ops.flat_structure(self))
    super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)
开发者ID:aritratony,项目名称:tensorflow,代码行数:9,代码来源:snapshot.py


示例19: __init__

 def __init__(self, input_dataset, op_function, tag):
   """Applies a caller-supplied stats op function to `input_dataset`.

   Args:
     input_dataset: the dataset to wrap.
     op_function: callable invoked with the input variant tensor, the tag
       tensor and the flat-structure keyword arguments to build the op.
     tag: converted to a string tensor and passed to `op_function`.
   """
   self._input_dataset = input_dataset
   self._op_function = op_function
   self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
   upstream = self._input_dataset._variant_tensor  # pylint: disable=protected-access
   variant_tensor = self._op_function(
       upstream, self._tag, **dataset_ops.flat_structure(self))
   super(_StatsDataset, self).__init__(input_dataset, variant_tensor)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:9,代码来源:stats_ops.py


示例20: _as_variant_tensor

 def _as_variant_tensor(self):
   """Returns a fused map-and-batch variant tensor over the input dataset.

   The values stored in `_batch_size_t`, `_num_parallel_calls_t` and
   `_drop_remainder_t` are forwarded to the op. The output type/shape
   keyword arguments come from `self._output_structure`.
   """
   # pylint: disable=protected-access
   upstream = self._input_dataset._as_variant_tensor()
   wrapped_fn = self._map_func.function
   return ged_ops.experimental_map_and_batch_dataset(
       upstream,
       wrapped_fn.captured_inputs,
       f=wrapped_fn,
       batch_size=self._batch_size_t,
       num_parallel_calls=self._num_parallel_calls_t,
       drop_remainder=self._drop_remainder_t,
       **dataset_ops.flat_structure(structure=self._output_structure))
开发者ID:aeverall,项目名称:tensorflow,代码行数:10,代码来源:batching.py



注:本文中的tensorflow.python.data.ops.dataset_ops.flat_structure函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python dataset_ops.get_legacy_output_shapes函数代码示例发布时间:2022-05-27
下一篇:
Python prefetching_ops.copy_to_device函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap