• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python dataset_ops.make_one_shot_iterator函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.data.ops.dataset_ops.make_one_shot_iterator函数的典型用法代码示例。如果您想了解：Python make_one_shot_iterator函数的具体用法是什么？make_one_shot_iterator怎么用？有哪些使用示例？那么这里精选的函数代码示例或许可以为您提供帮助。



下文共展示了make_one_shot_iterator函数的20个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于我们的系统推荐更好的Python代码示例。

示例1: _compare

  def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
    """Times a map+batch pipeline with and without map vectorization.

    The same `map(map_fn).batch(batch_size)` pipeline is built twice: a
    baseline with default optimizations turned off, and a variant that applies
    `map_vectorization` on top of the baseline. Each is run through
    `self._run`, whose return values are treated as elapsed times, and the
    ratio baseline / vectorized is printed as the speedup.

    Args:
      input_dataset: Dataset the pipeline maps over.
      map_fn: Function passed to `Dataset.map`.
      batch_size: Batch size applied after the map.
      input_size: Iterable of shape-like values; the products of their entries
        are summed to label the benchmark (presumably element shapes).
      str_id: Transformation identifier used in benchmark names and output.
    """
    total_elems = int(np.sum([np.prod(shape) for shape in input_size]))

    def bench_name(variant):
      return "{}_batch_size_{}_input_element_size_{}_{}".format(
          str_id, batch_size, total_elems, variant)

    baseline_options = dataset_ops.Options()
    baseline_options.experimental_optimization.apply_default_optimizations = (
        False)
    baseline = input_dataset.map(map_fn).batch(batch_size).with_options(
        baseline_options)
    baseline_next = dataset_ops.make_one_shot_iterator(baseline).get_next()

    # Applied on top of `baseline`, which already carries the options above.
    vectorize_options = dataset_ops.Options()
    vectorize_options.experimental_optimization.map_vectorization = True
    vectorized = baseline.with_options(vectorize_options)
    vectorized_next = dataset_ops.make_one_shot_iterator(vectorized).get_next()

    unoptimized_time = self._run(baseline_next, name=bench_name("unoptimized"))
    optimized_time = self._run(vectorized_next, name=bench_name("optimized"))

    speedup = unoptimized_time / optimized_time
    print("Batch size: {}\n"
          "Input element size: {}\n"
          "Transformation: {}\n"
          "Speedup: {}\n".format(batch_size, input_size, str_id, speedup))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:30,代码来源:map_vectorization_benchmark.py


示例2: test_metrics_correctness_with_iterator

  def test_metrics_correctness_with_iterator(self):
    """Checks `evaluate` metric values when fed one-shot dataset iterators.

    A two-layer model with all-ones kernel initializers is evaluated twice:
    on random binary labels, where both accuracy metrics round to 0.5, and on
    all-zero labels, where both accuracy metrics are exactly 0.
    """
    model = testing_utils.get_model_from_layers(
        [
            keras.layers.Dense(8, activation='relu', input_dim=4,
                               kernel_initializer='ones'),
            keras.layers.Dense(1, activation='sigmoid',
                               kernel_initializer='ones'),
        ],
        (4,))
    model.compile(
        loss='binary_crossentropy',
        metrics=['accuracy', metrics_module.BinaryAccuracy()],
        optimizer='rmsprop',
        run_eagerly=testing_utils.should_run_eagerly())

    def evaluate_on(features, labels, repeat=None):
      # Batches of 10, consumed through a one-shot iterator for 10 steps.
      ds = dataset_ops.Dataset.from_tensor_slices((features, labels))
      if repeat is not None:
        ds = ds.repeat(repeat)
      return model.evaluate(
          dataset_ops.make_one_shot_iterator(ds.batch(10)), steps=10)

    np.random.seed(123)
    x = np.random.randint(10, size=(100, 4)).astype(np.float32)
    y = np.random.randint(2, size=(100, 1)).astype(np.float32)
    outs = evaluate_on(x, y)
    self.assertEqual(np.around(outs[1], decimals=1), 0.5)
    self.assertEqual(np.around(outs[2], decimals=1), 0.5)

    outs = evaluate_on(x, np.zeros((100, 1), dtype=np.float32), repeat=100)
    self.assertEqual(outs[1], 0.)
    self.assertEqual(outs[2], 0.)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:33,代码来源:training_dataset_test.py


示例3: testSaveRestoreMultipleIterator

 def testSaveRestoreMultipleIterator(self):
   """Checkpoints three iterators together and restores their positions.

   Iterators 1 and 2 share a squared, batch-of-2 dataset; iterator 3 walks
   `range(10)`. The checkpoint is saved after advancing iterators 1 and 3;
   iterators 2 and 3 are then advanced further, and after restoring every
   iterator must resume from where it stood at save time.
   """
   eager = context.executing_eagerly()

   def make_iterator(ds):
     # Returns (iterator, zero-arg callable producing the next value).
     if eager:
       it = iter(ds)
       return it, it.get_next
     it = dataset_ops.make_one_shot_iterator(ds)
     return it, functools.partial(self.evaluate, it.get_next())

   checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
   dataset = dataset_ops.Dataset.from_tensor_slices(
       [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]).map(math_ops.square).batch(2)
   iterator_1, get_next_1 = make_iterator(dataset)
   iterator_2, get_next_2 = make_iterator(dataset)
   iterator_3, get_next_3 = make_iterator(dataset_ops.Dataset.range(10))
   checkpoint = trackable_utils.Checkpoint(
       iterator_1=iterator_1, iterator_2=iterator_2, iterator_3=iterator_3)

   self.assertAllEqual([1, 4], get_next_1())
   for expected in (0, 1, 2):
     self.assertAllEqual(expected, get_next_3())
   save_path = checkpoint.save(checkpoint_prefix)

   # Move past the saved state; the restore below must undo this.
   self.assertAllEqual([1, 4], get_next_2())
   self.assertAllEqual([9, 16], get_next_2())
   self.assertAllEqual(3, get_next_3())

   checkpoint.restore(save_path).run_restore_ops()
   self.assertAllEqual([9, 16], get_next_1())
   self.assertAllEqual([1, 4], get_next_2())
   self.assertAllEqual(3, get_next_3())
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:33,代码来源:iterator_checkpoint_test.py


示例4: test_model_fit_and_validation_with_missing_arg_errors

  def test_model_fit_and_validation_with_missing_arg_errors(self):
    """Checks that `fit` rejects iterator inputs without step arguments.

    Covers a training iterator without `steps_per_epoch`, tensor validation
    data without `batch_size`/`validation_steps` (graph mode only), and
    infinite validation data, given as a dataset or as an iterator, without
    `validation_steps`.
    """
    model = testing_utils.get_small_mlp(10, 4, 3)
    model.compile(optimizer=rmsprop.RMSprop(learning_rate=0.001),
                  loss='mse',
                  run_eagerly=True)

    x = array_ops.zeros(shape=(10, 3))
    y = array_ops.zeros(shape=(10, 4))
    iterator = dataset_ops.make_one_shot_iterator(
        dataset_ops.Dataset.from_tensor_slices((x, y)).repeat(10).batch(5))
    # `repeat()` without a count makes the validation data infinite.
    validation_dataset = dataset_ops.Dataset.from_tensor_slices(
        (x, y)).repeat().batch(5)
    validation_iterator = dataset_ops.make_one_shot_iterator(validation_dataset)

    with self.assertRaisesRegexp(
        ValueError, r'specify .* `steps_per_epoch`'):
      model.fit(iterator, epochs=1, verbose=0)

    if not context.executing_eagerly():
      # In eager execution, `array_ops.zeros` returns value tensors
      # which can be used for validation without a `validation_steps` argument.
      with self.assertRaisesRegexp(
          ValueError, r'provide either `batch_size` or `validation_steps`'):
        model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
                  validation_data=(x, y))

    # A step count is required for infinite validation data in either form.
    for validation_data in (validation_dataset, validation_iterator):
      with self.assertRaisesRegexp(ValueError,
                                   'specify the `validation_steps` argument.'):
        model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
                  validation_data=validation_data)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:33,代码来源:training_eager_test.py


示例5: testMapNamedtuple

  def testMapNamedtuple(self, count=10):
    """Checks that `map` treats namedtuple elements like plain tuples.

    Builds a dataset of `(label, image)` tuples and a parallel dataset of
    `Example` namedtuples, doubles the image in each, and verifies that both
    yield the same `(i, -2 * i)` elements and that BOTH are exhausted after
    `count` elements.

    Args:
      count: Number of elements in each dataset.
    """
    # construct dataset of tuples
    labels = dataset_ops.Dataset.range(count)
    images = labels.map(lambda l: -l)
    dataset_tuple = dataset_ops.Dataset.zip((labels, images))

    # convert dataset of tuples to dataset of namedtuples
    example = namedtuple("Example", ["label", "image"])
    dataset_namedtuple = dataset_tuple.map(example)

    def preprocess_tuple(label, image):
      image = 2 * image
      return label, image

    def preprocess_namedtuple(example):
      return example._replace(image=2 * example.image)

    # preprocess both datasets
    dataset_tuple = dataset_tuple.map(preprocess_tuple)
    dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple)

    next_tuple = dataset_ops.make_one_shot_iterator(dataset_tuple).get_next()
    next_namedtuple = dataset_ops.make_one_shot_iterator(
        dataset_namedtuple).get_next()

    # make sure both datasets contain the same data
    with self.cached_session() as sess:
      for i in range(count):
        tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple])
        self.assertEqual(tuple_, namedtuple_)
        self.assertEqual(tuple_, (i, -2 * i))

      # Both pipelines must be drained; previously only the namedtuple one
      # was checked, so an extra element in the tuple pipeline went unnoticed.
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_tuple)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(next_namedtuple)
开发者ID:aeverall,项目名称:tensorflow,代码行数:34,代码来源:map_test.py


示例6: testCapturingStateInOneShotRaisesException

 def testCapturingStateInOneShotRaisesException(self):
   """Checks one-shot iterators reject datasets that capture a variable.

   The map function closes over the variable `var` (named "myvar"), so
   `make_one_shot_iterator` must raise a `ValueError` whose message mentions
   the captured variable.
   """
   var = variables.Variable(37.0, name="myvar")

   def add_var(x):
     return x + var

   dataset = dataset_ops.Dataset.from_tensor_slices([0.0, 1.0, 2.0]).map(
       add_var)
   expected_message = (
       r"`Dataset.make_one_shot_iterator\(\)` does not support "
       "datasets that capture stateful objects.+myvar")
   with self.assertRaisesRegexp(ValueError, expected_message):
     dataset_ops.make_one_shot_iterator(dataset)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:9,代码来源:iterator_test.py


示例7: testIteratorStringHandle

  def testIteratorStringHandle(self):
    """Checks a feedable iterator switching between two one-shot iterators.

    An iterator built with `Iterator.from_string_handle` is fed, per step, the
    string handle of one of two underlying iterators. Each underlying iterator
    advances independently and raises `OutOfRangeError` once drained.
    """
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder, dataset_ops.get_legacy_output_types(dataset_3),
        dataset_ops.get_legacy_output_shapes(dataset_3))
    next_element = feedable_iterator.get_next()

    for source in (dataset_3, dataset_4):
      self.assertTrue(dataset_ops.get_structure(source).is_compatible_with(
          dataset_ops.get_structure(feedable_iterator)))

    with self.cached_session() as sess:
      handle_3 = sess.run(iterator_3.string_handle())
      handle_4 = sess.run(iterator_4.string_handle())

      def next_from(handle):
        return sess.run(next_element, feed_dict={handle_placeholder: handle})

      # Interleave the two iterators; each keeps its own position.
      schedule = [(10, handle_4), (1, handle_3), (20, handle_4), (2, handle_3),
                  (30, handle_4), (3, handle_3), (40, handle_4)]
      for expected, handle in schedule:
        self.assertEqual(expected, next_from(handle))

      for handle in (handle_3, handle_4):
        with self.assertRaises(errors.OutOfRangeError):
          next_from(handle)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:56,代码来源:iterator_test.py


示例8: testSkipEagerMultipleIterators

  def testSkipEagerMultipleIterators(self, reshuffle, initializable):
    """Checks that two iterators over one shuffled dataset yield different orders.

    Args:
      reshuffle: Passed to `shuffle` as `reshuffle_each_iteration`.
      initializable: If True, use initializable iterators and run their
        initializers; otherwise use one-shot iterators.
    """
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.range(100).shuffle(
          10, reshuffle_each_iteration=reshuffle).repeat(3)

      make_iterator = (dataset_ops.make_initializable_iterator
                       if initializable else dataset_ops.make_one_shot_iterator)
      iterators = [make_iterator(dataset) for _ in range(2)]

      results = []
      with self.session(graph=g) as sess:
        for iterator in iterators:
          if initializable:
            sess.run(iterator.initializer)
          next_element = iterator.get_next()
          # range(100) repeated 3 times yields exactly 300 elements.
          run_results = [sess.run(next_element) for _ in range(300)]
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(next_element)
          results.append(run_results)

        self.assertNotEqual(results[0], results[1])
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:27,代码来源:shuffle_test.py


示例9: testFromGenerator

  def testFromGenerator(self):
    """Checks the `IteratorGetNext` shape reported for `from_generator` data.

    For scalar, vector and matrix generator outputs, a one-shot iterator is
    built in a fresh graph and exported with `create_meta_graph_def`. The shape
    of the first `IteratorGetNext` op property reported by `item.Item` must
    equal the declared `output_shapes`.
    """

    def make_generator(tensor):
      # Binds `tensor` now so each generator yields its own test value.
      def generator():
        yield tensor

      return generator

    test_cases = [
        {'tensor': 0, 'shape': tensor_shape.TensorShape([])},
        {'tensor': np.array([1, 2, 3]), 'shape': tensor_shape.TensorShape([3])},
        {'tensor': np.array([[1, 2, 3]]),
         'shape': tensor_shape.TensorShape([1, 3])},
    ]

    for test_case in test_cases:
      expected_shape = test_case['shape']
      with ops.Graph().as_default() as g:
        dataset = dataset_ops.Dataset.from_generator(
            make_generator(test_case['tensor']),
            dtypes.int64,
            output_shapes=expected_shape)
        get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()
        # Registered under TRAIN_OP, presumably so the exported graph treats
        # it as a fetch node.
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(get_next)
        grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=g))
        op_properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         op_properties['IteratorGetNext'][0].shape)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:35,代码来源:datasets_test.py


示例10: testOneShotIteratorCaptureByValue

  def testOneShotIteratorCaptureByValue(self):
    """Checks a one-shot iterator over a dataset built from constant tensors.

    The dataset slices three converted tensors, squares every component and
    repeats 14 times. Every element must equal the squared input slice, and
    the iterator must be exhausted afterwards.
    """
    components = (np.arange(7),
                  np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
                  np.array(37.0) * np.arange(7))
    tensor_components = tuple(ops.convert_to_tensor(c) for c in components)

    def square_all(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = dataset_ops.Dataset.from_tensor_slices(tensor_components).map(
        square_all).repeat(14)
    get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()

    # Each element's static shape drops the leading (sliced) dimension.
    expected_shapes = [c.shape[1:] for c in components]
    self.assertEqual(expected_shapes, [t.shape for t in get_next])

    with self.cached_session() as sess:
      for _ in range(14):
        for i in range(7):
          result = sess.run(get_next)
          for expected_component, actual in zip(components, result):
            self.assertAllEqual(expected_component[i]**2, actual)
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:25,代码来源:iterator_test.py


示例11: getNext

  def getNext(self, dataset, requires_initialization=False):
    """Builds a zero-argument callable that fetches the next dataset element.

    Works the same way from a test's point of view in graph and eager modes:
    ```python
    dataset = ...
    get_next = self.getNext(dataset)
    result = self.evaluate(get_next())
    ```

    Args:
      dataset: The dataset to iterate over.
      requires_initialization: Graph mode only. If True, an initializable
        iterator is used and initialized here (e.g. when the dataset has
        stateful nodes); otherwise a one-shot iterator is used. Defaults to
        False.
    Returns:
      A callable that returns the next element of `dataset` on each call. In
      graph mode it returns the same `get_next` tensor every time.
    """
    if context.executing_eagerly():
      eager_iterator = dataset.__iter__()
      return eager_iterator._next_internal  # pylint: disable=protected-access

    if requires_initialization:
      graph_iterator = dataset_ops.make_initializable_iterator(dataset)
      self.evaluate(graph_iterator.initializer)
    else:
      graph_iterator = dataset_ops.make_one_shot_iterator(dataset)
    next_element = graph_iterator.get_next()

    def get_next():
      return next_element

    return get_next
开发者ID:aeverall,项目名称:tensorflow,代码行数:30,代码来源:test_base.py


示例12: testMapAndBatchOutOfRangeError

  def testMapAndBatchOutOfRangeError(self, threshold, numa_aware):
    """Checks that StopIteration in a py_func ends `map_and_batch` cleanly.

    The mapped py_func raises `StopIteration` at `threshold`. The dataset must
    yield the full batches of 10 below it, then a partial batch if
    `threshold` is not a multiple of 10, then `OutOfRangeError`. Inputs past
    `threshold` raise a different error that must never surface.

    Args:
      threshold: Input value at which the mapped function raises
        StopIteration.
      numa_aware: If True, enable `experimental_numa_aware` on the dataset.
    """

    def raising_py_fn(i):
      if i < threshold:
        return i
      if i == threshold:
        raise StopIteration()
      raise RuntimeError("Alternate error; you shouldn't see me! (i: %s)" % i)

    dataset = dataset_ops.Dataset.range(100).apply(
        batching.map_and_batch(
            lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
            batch_size=10))
    if numa_aware:
      options = dataset_ops.Options()
      options.experimental_numa_aware = True
      dataset = dataset.with_options(options)
    get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()

    full_batches, remainder = divmod(threshold, 10)
    with self.cached_session():
      for b in range(full_batches):
        self.assertAllEqual([b * 10 + j for j in range(10)],
                            self.evaluate(get_next))
      if remainder:
        self.assertAllEqual([full_batches * 10 + j for j in range(remainder)],
                            self.evaluate(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next)
开发者ID:aeverall,项目名称:tensorflow,代码行数:31,代码来源:map_and_batch_test.py


示例13: test_sequential_deferred_build_with_dataset_iterators

  def test_sequential_deferred_build_with_dataset_iterators(self):
    """Checks that a deferred-build Sequential model builds during `fit`.

    Before `fit` the two-layer model has no weights and is not built. After
    `fit` on a one-shot iterator it is built, has 2 * 2 weights, and is not
    a graph network.
    """
    num_hidden, input_dim, num_classes = 5, 3, 2
    num_samples, steps_per_epoch = 50, 10

    model = testing_utils.get_small_sequential_mlp(num_hidden, num_classes)
    model.compile(
        loss='mse',
        optimizer='rmsprop',
        metrics=[keras.metrics.CategoricalAccuracy()],
        run_eagerly=testing_utils.should_run_eagerly())
    self.assertEqual(len(model.layers), 2)
    self.assertEqual(len(model.weights), 0)
    self.assertFalse(model.built)

    features = array_ops.ones((num_samples, input_dim))
    targets = array_ops.zeros((num_samples, num_classes))
    dataset = dataset_ops.Dataset.from_tensor_slices(
        (features, targets)).repeat(100).batch(10)
    iterator = dataset_ops.make_one_shot_iterator(dataset)

    model.fit(iterator, epochs=1, steps_per_epoch=steps_per_epoch)
    self.assertTrue(model.built)
    self.assertEqual(len(model.weights), 2 * 2)
    self.assertFalse(model._is_graph_network)
开发者ID:gautam1858,项目名称:tensorflow,代码行数:28,代码来源:sequential_test.py


示例14: testMapAndBatchParallelGetNextDropRemainder

  def testMapAndBatchParallelGetNextDropRemainder(self, numa_aware):
    """Checks 100 `get_next` calls evaluated together on a drop-remainder batch.

    `range(49999)` in batches of 100 with `drop_remainder=True` gives 499
    full batches. Four rounds of 100 `get_next` calls together cover values
    0..39999. A fifth round needs 100 more batches but only 99 remain, so it
    raises `OutOfRangeError`.

    Args:
      numa_aware: If True, enable `experimental_numa_aware` on the dataset.
    """
    dataset = dataset_ops.Dataset.range(49999).apply(
        batching.map_and_batch(
            lambda x: x, batch_size=100, drop_remainder=True))

    if numa_aware:
      options = dataset_ops.Options()
      options.experimental_numa_aware = True
      dataset = dataset.with_options(options)

    if context.executing_eagerly():
      get_next = iter(dataset)._next_internal  # pylint: disable=protected-access
    else:
      # Graph mode: each call presumably builds a fresh get_next op, so one
      # evaluate() runs all 100 of them.
      get_next = dataset_ops.make_one_shot_iterator(dataset).get_next

    elements = [get_next] * 100

    for i in range(4):
      got = self.evaluate([element() for element in elements])
      # Batch order across the calls is not relied upon; sort by first value.
      got.sort(key=lambda batch: batch[0])
      expected = [range(i * 10000 + j * 100, i * 10000 + (j + 1) * 100)
                  for j in range(100)]
      self.assertAllEqual(got, expected)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate([element() for element in elements])
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:30,代码来源:map_and_batch_test.py


示例15: _test_bucket_by_padding

    def _test_bucket_by_padding(no_padding):
      """Runs one `bucket_by_sequence_length` round trip and checks batches.

      Uses `self`, `build_dataset`, `_element_length_fn`, `boundaries`,
      `batch_sizes` and `lengths` from the enclosing scope, which is not shown
      here. Exactly four batches must be produced.

      Args:
        no_padding: Passed to `build_dataset(sparse=...)` and to the bucketing
          transformation. If True, batches are inspected through
          `dense_shape`/`values`; otherwise as dense arrays.
      """
      dataset = build_dataset(sparse=no_padding).apply(
          grouping.bucket_by_sequence_length(
              _element_length_fn,
              boundaries,
              batch_sizes,
              no_padding=no_padding))
      batch, = dataset_ops.make_one_shot_iterator(dataset).get_next()

      with self.cached_session():
        batches = [self.evaluate(batch) for _ in range(4)]
        with self.assertRaises(errors.OutOfRangeError):
          self.evaluate(batch)

      batch_sizes_val = []
      lengths_val = []
      for value in batches:
        shape = value.dense_shape if no_padding else value.shape
        batch_size, length = shape[0], shape[1]
        batch_sizes_val.append(batch_size)
        lengths_val.append(length)
        total = value.values.sum() if no_padding else value.sum()
        self.assertEqual(total, batch_size * length - 1)
      self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
      self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
      self.assertEqual(sorted(lengths), sorted(lengths_val))
开发者ID:aeverall,项目名称:tensorflow,代码行数:29,代码来源:bucket_by_sequence_length_test.py


示例16: testPadToBoundaryNoExtraneousPadding

  def testPadToBoundaryNoExtraneousPadding(self):
    """Checks `pad_to_bucket_boundary` pads to one below each bucket boundary.

    Sequences of ones with lengths 1..10 are bucketed at boundaries [3, 7, 11]
    in batches of 2. Each batch must be zero-padded to boundary - 1 (2, 6 or
    10 columns) and no further.
    """
    boundaries = [3, 7, 11]
    batch_sizes = [2, 2, 2, 2]
    lengths = range(1, 11)

    def element_gen():
      for length in lengths:
        yield ([1] * length,)

    def element_len(element):
      return array_ops.shape(element)[0]

    dataset = dataset_ops.Dataset.from_generator(
        element_gen, (dtypes.int64,), ([None],)).apply(
            grouping.bucket_by_sequence_length(
                element_len, boundaries, batch_sizes,
                pad_to_bucket_boundary=True))
    batch, = dataset_ops.make_one_shot_iterator(dataset).get_next()

    with self.cached_session():
      batches = [self.evaluate(batch) for _ in range(5)]
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(batch)

    expected_batches = [
        [[1, 0],
         [1, 1]],
        [[1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 0, 0]],
        [[1, 1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1, 1]],
        [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]],
        [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
    ]
    for actual, expected in zip(batches, expected_batches):
      self.assertAllEqual(actual, expected)
开发者ID:aeverall,项目名称:tensorflow,代码行数:35,代码来源:bucket_by_sequence_length_test.py


示例17: testPrefetchToSameDevice

  def testPrefetchToSameDevice(self):
    """Checks `prefetch_to_device` targeting CPU:0 preserves dataset metadata.

    The iterator is created under a `/cpu:1` device scope. Its output types,
    shapes and classes must match the host dataset's, the element must be a
    scalar int64, and it must yield 0..9 before `OutOfRangeError`.
    """
    host_dataset = dataset_ops.Dataset.range(10)
    device_dataset = host_dataset.apply(
        prefetching_ops.prefetch_to_device(
            "/job:localhost/replica:0/task:0/device:CPU:0"))

    with ops.device("/cpu:1"):
      iterator = dataset_ops.make_one_shot_iterator(device_dataset)
      next_element = iterator.get_next()

    for attr in ("output_types", "output_shapes", "output_classes"):
      expected = getattr(host_dataset, attr)
      self.assertEqual(expected, getattr(device_dataset, attr))
      self.assertEqual(expected, getattr(iterator, attr))

    self.assertEqual(dtypes.int64, next_element.dtype)
    self.assertEqual([], next_element.shape)

    with self.cached_session():
      for expected_value in range(10):
        self.assertEqual(expected_value, self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
开发者ID:aeverall,项目名称:tensorflow,代码行数:25,代码来源:prefetch_to_device_test.py


示例18: testMapAndBatchPartialBatch

  def testMapAndBatchPartialBatch(self, drop_remainder, numa_aware):
    """Checks `map_and_batch` handling of a trailing partial batch.

    `range(10)` is squared into shape-[1] elements and batched by 4. The last
    two elements form a partial batch, which is emitted only when
    `drop_remainder` is False; in that case the static batch dimension is
    also unknown.

    Args:
      drop_remainder: Passed to `map_and_batch`.
      numa_aware: If True, enable `experimental_numa_aware` on the dataset.
    """
    dataset = dataset_ops.Dataset.range(10).apply(
        batching.map_and_batch(
            lambda x: array_ops.reshape(x * x, [1]),
            batch_size=4,
            drop_remainder=drop_remainder))

    if numa_aware:
      options = dataset_ops.Options()
      options.experimental_numa_aware = True
      dataset = dataset.with_options(options)
    iterator = dataset_ops.make_one_shot_iterator(dataset)

    expected_shape = [4, 1] if drop_remainder else [None, 1]
    self.assertEqual(expected_shape, iterator.output_shapes.as_list())

    expected_batches = [[[0], [1], [4], [9]], [[16], [25], [36], [49]]]
    if not drop_remainder:
      expected_batches.append([[64], [81]])

    next_element = iterator.get_next()
    with self.cached_session():
      for expected in expected_batches:
        self.assertAllEqual(expected, self.evaluate(next_element))
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)
开发者ID:aeverall,项目名称:tensorflow,代码行数:26,代码来源:map_and_batch_test.py


示例19: testMapAndBatchImplicitDispose

  def testMapAndBatchImplicitDispose(self, numa_aware):
    """Checks cleanup of a map-and-batch pipeline that is not run to the end.

    The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
    MapAndBatchDataset(f=square_3, batch_size=100) -> prefetch(5). Only three
    batches are pulled. The test makes no value assertions; it exercises
    disposal of the unfinished pipeline.

    Args:
      numa_aware: If True, enable `experimental_numa_aware` on the dataset.
    """
    num_rows = 1000
    components = (np.arange(num_rows),
                  np.array([[1, 2, 3]]) * np.arange(num_rows)[:, np.newaxis],
                  np.array(37.0) * np.arange(num_rows))

    def square_3(x, y, z):
      return math_ops.square(x), math_ops.square(y), math_ops.square(z)

    dataset = (
        dataset_ops.Dataset.from_tensor_slices(components)
        .repeat(1000)
        .apply(batching.map_and_batch(square_3, batch_size=100))
        .prefetch(5))
    if numa_aware:
      options = dataset_ops.Options()
      options.experimental_numa_aware = True
      dataset = dataset.with_options(options)
    get_next = dataset_ops.make_one_shot_iterator(dataset).get_next()

    with self.cached_session():
      for _ in range(3):
        self.evaluate(get_next)
开发者ID:aeverall,项目名称:tensorflow,代码行数:25,代码来源:map_and_batch_test.py


示例20: _benchmarkFilters

  def _benchmarkFilters(self, chain_length, optimize_dataset):
    """Benchmarks per-element latency of a chain of `filter` transformations.

    The source repeats the scalar 5 forever. Each filter keeps elements with
    x - 5 >= 0, so every filter passes every element. After 10 warm-up runs,
    100 samples of 100 `get_next` runs each are timed. The median per-run
    wall time is printed and reported.

    Args:
      chain_length: Number of filters to chain.
      optimize_dataset: If True, apply the "filter_fusion" optimization.
    """
    warmup_runs = 10
    num_samples = 100
    runs_per_sample = 100
    with ops.Graph().as_default():
      dataset = dataset_ops.Dataset.from_tensors(5).repeat(None)
      for _ in range(chain_length):
        dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0))
      if optimize_dataset:
        dataset = dataset.apply(optimization.optimize(["filter_fusion"]))

      next_element = dataset_ops.make_one_shot_iterator(dataset).get_next()

      with session.Session():
        for _ in range(warmup_runs):
          self.evaluate(next_element.op)
        deltas = []
        for _ in range(num_samples):
          start = time.time()
          for _ in range(runs_per_sample):
            self.evaluate(next_element.op)
          deltas.append(time.time() - start)

        median_wall_time = np.median(deltas) / runs_per_sample
        opt_mark = "opt" if optimize_dataset else "no-opt"
        print("Filter dataset {} chain length: {} Median wall time: {}".format(
            opt_mark, chain_length, median_wall_time))
        # NOTE(review): `iters` is reported as 1000, but the timed loops run
        # num_samples * runs_per_sample iterations; confirm the intended value.
        self.report_benchmark(
            iters=1000,
            wall_time=median_wall_time,
            name="benchmark_filter_dataset_chain_latency_{}_{}".format(
                opt_mark, chain_length))
开发者ID:aeverall,项目名称:tensorflow,代码行数:31,代码来源:filter_dataset_op_test.py



注:本文中的tensorflow.python.data.ops.dataset_ops.make_one_shot_iterator函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap