• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python nest.flatten函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了 Python 中 tensorflow.python.util.nest.flatten 函数的典型用法代码示例。如果您想了解 nest.flatten 函数的具体用法、调用方式或实际使用示例，这里精选的代码示例或许可以为您提供帮助。



下文共展示了 flatten 函数的 20 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将帮助我们的系统推荐更好的 Python 代码示例。

示例1: convert_to_generator_like

def convert_to_generator_like(data,
                              batch_size=None,
                              steps_per_epoch=None,
                              epochs=1,
                              shuffle=False):
  """Make a generator out of NumPy or EagerTensor inputs.

  Arguments:
    data: Either a generator or `keras.utils.data_utils.Sequence` object or
      `Dataset` or `EagerIterator` or a {1,2,3}-tuple of NumPy arrays or
      EagerTensors. If a tuple, the elements represent `(x, y, sample_weights)`
      and may be `None` or `[None]`.
    batch_size: Used when creating a generator out of tuples of NumPy arrays or
      EagerTensors. Required in that case.
    steps_per_epoch: Steps of the generator to run each epoch. Replaced by
      `len(data)` for a `Sequence`, and recomputed from the number of samples
      for NumPy/EagerTensor inputs.
    epochs: Total number of epochs to run.
    shuffle: Whether the data should be shuffled.

  Returns:
    A `(generator_like, steps_per_epoch)` tuple, where `generator_like` is a
    generator, a `keras.utils.data_utils.Sequence`, or an iterator.

  Raises:
    ValueError: If `batch_size` is not provided for NumPy or EagerTensor
      inputs, or if `data` holds no arrays after `None`s are removed.
  """
  if isinstance(data, tuple):
    # Scrub `Nones` that might have been passed for `targets`, `sample_weights`.
    data = tuple(
        ele for ele in data if not all(e is None for e in nest.flatten(ele)))
    if len(data) == 1:
      data = data[0]

  if data_utils.is_generator_or_sequence(data) or isinstance(
      data, iterator_ops.EagerIterator):
    if isinstance(data, data_utils.Sequence):
      steps_per_epoch = len(data)
    return data, steps_per_epoch
  if isinstance(data, dataset_ops.DatasetV2):
    return dataset_ops.make_one_shot_iterator(data), steps_per_epoch

  # Create generator from NumPy or EagerTensor Input.
  # Validate `batch_size` before inspecting `data`, so a missing batch size
  # always surfaces as the documented ValueError.
  if batch_size is None:
    raise ValueError('You must specify `batch_size`')
  flat_data = nest.flatten(data)
  if not flat_data:
    # E.g. a tuple whose elements were all `None` and got scrubbed above.
    raise ValueError('`data` contains no arrays to build a generator from.')
  num_samples = int(flat_data[0].shape[0])
  # Force true division so the ceiling is correct even under Python 2
  # integer-division semantics.
  steps_per_epoch = int(math.ceil(num_samples / float(batch_size)))

  def _gen(data):
    """Makes a generator out of a structure of NumPy/EagerTensors."""
    index_array = np.arange(num_samples)
    for _ in range(epochs):
      if shuffle:
        np.random.shuffle(index_array)
      batches = generic_utils.make_batches(num_samples, batch_size)
      for (batch_start, batch_end) in batches:
        batch_ids = index_array[batch_start:batch_end]
        # Slices are contiguous only when the index order was not shuffled.
        flat_batch_data = training_utils.slice_arrays(
            nest.flatten(data), batch_ids, contiguous=(not shuffle))
        yield nest.pack_sequence_as(data, flat_batch_data)

  return _gen(data), steps_per_epoch
开发者ID:aeverall,项目名称:tensorflow,代码行数:60,代码来源:training_generator.py


示例2: _eager_metrics_fn

def _eager_metrics_fn(model,
                      outputs,
                      targets,
                      sample_weights=None,
                      masks=None,
                      return_stateful_result=True):
  """Compute the per-output metric values for `model` in eager mode.

  Arguments:
      model: The model whose metrics are evaluated.
      outputs: The model outputs; nested structures are flattened first.
      targets: The predictions or targets of the model; also flattened.
      sample_weights: Optional list of sample weights for each output.
      masks: Optional list of masks for each output.
      return_stateful_result: Boolean; whether the stateful (aggregated) or
        the stateless metric result should be returned.

  Returns:
      A list with `backend.mean` of each result returned by
      `model._handle_metrics`.
  """
  flat_outputs = nest.flatten(outputs)
  flat_targets = nest.flatten(targets)
  # TODO(psv): Consider supporting skip target indices in eager mode?
  raw_results = model._handle_metrics(
      flat_outputs,
      targets=flat_targets,
      sample_weights=sample_weights,
      masks=masks,
      return_stateful_result=return_stateful_result)
  averaged = []
  for result in raw_results:
    averaged.append(backend.mean(result))
  return averaged
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:30,代码来源:training_eager.py


示例3: _create_multi_lstm_cell_ops

def _create_multi_lstm_cell_ops(batch_size, num_units, input_depth,
                                num_layers, max_time, compiled):
  """Build a stacked-LSTM `dynamic_rnn` graph together with its gradients.

  Inside variable scope "root" (initializer
  `random_uniform_initializer(-0.1, 0.1, seed=2)`) this creates an `inputs`
  variable of shape `(max_time, batch_size, input_depth)`, a MultiRNNCell of
  `num_layers` LSTMCells (each wrapped in `rnn_cell.CompiledWrapper` when
  `compiled` is true), and runs `rnn.dynamic_rnn` time-major from the cell's
  zero state.

  Returns:
    A dict with the RNN "outputs", the flattened "final_state", and the
    gradients of each ("outputs_grad", "final_state_grad") with respect to all
    trainable variables, `inputs`, and the flattened initial state.
  """
  root_initializer = init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)
  with variable_scope.variable_scope("root", initializer=root_initializer):
    inputs = variable_scope.get_variable(
        "inputs", initializer=random_ops.random_uniform(
            (max_time, batch_size, input_depth), seed=1))

    def _maybe_xla(layer_cell):
      # Wrap in CompiledWrapper only when `compiled` was requested.
      if compiled:
        return rnn_cell.CompiledWrapper(layer_cell)
      return layer_cell

    layer_cells = [_maybe_xla(core_rnn_cell_impl.LSTMCell(num_units))
                   for _ in range(num_layers)]
    cell = core_rnn_cell_impl.MultiRNNCell(layer_cells)
    initial_state = cell.zero_state(
        batch_size=batch_size, dtype=dtypes.float32)
    outputs, final_state = rnn.dynamic_rnn(
        cell=cell, inputs=inputs, initial_state=initial_state,
        time_major=True)
    flat_final_state = nest.flatten(final_state)
    # Both gradient computations differentiate w.r.t. the same tensors.
    grad_targets = (variables.trainable_variables() + [inputs] +
                    nest.flatten(initial_state))
    outputs_grad = gradients_impl.gradients([outputs], grad_targets)
    final_state_grad = gradients_impl.gradients(flat_final_state, grad_targets)

  return {"outputs": outputs,
          "final_state": flat_final_state,
          "outputs_grad": outputs_grad,
          "final_state_grad": final_state_grad}
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:30,代码来源:rnn_cell_test.py


示例4: call

  def call(self, inputs, mask=None):
    """Re-apply every op of the network graph to new inputs.

    This builds a new computation from `inputs`. In graph mode, if the network
    was already called on these exact inputs and masks, the result stored in
    `self._output_tensor_cache` is returned instead.

    Arguments:
        inputs: A tensor or (nested) list of tensors.
        mask: A mask or list of masks; each mask is a tensor or None. When
            omitted, every input gets a `None` mask.

    Returns:
        A tensor if there is a single output, or
        a list of tensors if there are more than one outputs.
    """
    flat_inputs = nest.flatten(inputs)
    if mask is not None:
      flat_masks = nest.flatten(mask)
    else:
      flat_masks = [None] * len(flat_inputs)

    if context.in_graph_mode():
      # Look for outputs cached from an earlier call on these exact inputs.
      input_uid = layers_util.object_list_uid(flat_inputs)
      mask_uid = layers_util.object_list_uid(flat_masks)
      cache_key = input_uid + '_' + mask_uid
      if cache_key in self._output_tensor_cache:
        return self._output_tensor_cache[cache_key]  # Cache hit.

    # Actually apply the network graph to the new inputs.
    outputs, _ = self._run_internal_graph(flat_inputs, flat_masks)
    return outputs
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:33,代码来源:network.py


示例5: _eager_metrics_fn

def _eager_metrics_fn(model, outputs, targets, sample_weights=None, masks=None):
  """Collect the eager-mode metric results for every output of `model`.

  Arguments:
      model: The model whose metrics are evaluated.
      outputs: The model outputs (flattened before use).
      targets: The predictions or targets (flattened before use).
      sample_weights: Optional list of sample weights for each output.
      masks: Optional list of masks for each output.

  Returns:
      A list: the weighted and unweighted results from
      `model._handle_metrics` (only when there are targets), followed by
      `result()` of each metric in `model.metrics` that is not in
      `model._compile_metric_functions`.
  """
  flat_outputs = nest.flatten(outputs)
  flat_targets = nest.flatten(targets)
  # TODO(psv): Consider supporting skip target indices in eager mode?
  if not flat_targets:
    results = []
  else:
    # Invoke all (weighted and unweighted) metrics.
    results = model._handle_metrics(
        flat_outputs,
        targets=flat_targets,
        sample_weights=sample_weights,
        masks=masks,
        return_weighted_and_unweighted_metrics=True)

  # Append the results of the `add_metric` metrics, i.e. those not listed in
  # `_compile_metric_functions`.
  for metric in model.metrics:
    if metric not in model._compile_metric_functions:
      results.append(metric.result())
  return results
开发者ID:aritratony,项目名称:tensorflow,代码行数:33,代码来源:training_eager.py


示例6: _get_cached_states

  def _get_cached_states(self, times):
    """Look up the cached state for every entry in a batch of times.

    Each time is mapped to a chunk number and looked up in
    `self._cached_states`. Entries that fall in the chunk of the series start
    time are replaced by `self._start_state`, replicated over the batch.

    Args:
      times: Batch of times; `array_ops.shape(times)[0]` is the batch size.

    Returns:
      The looked-up state (a tuple structure) with start-chunk entries taken
      from the model's start state.
    """
    chunk_numbers = self._get_chunk_number(times)
    cached_state = tuple(self._cached_states.lookup(
        math_ops.cast(chunk_numbers, dtypes.int64)))
    # The first chunk of a series must read the model's start state directly
    # so that gradients flow back to it. Otherwise the start state would only
    # affect initialization and would never be read or updated in training;
    # that isolated part of the graph also causes errors on model reload when
    # trainable variables affect the start state.
    start_time = (0 if self._input_statistics is None
                  else self._input_statistics.start_time)
    in_start_chunk = math_ops.equal(chunk_numbers,
                                    self._get_chunk_number(start_time))
    batched_start_state = math_utils.replicate_state(
        self._start_state, array_ops.shape(times)[0])
    merged = [
        array_ops.where(in_start_chunk, start_value, cached_value)
        for start_value, cached_value in zip(
            nest.flatten(batched_start_state), nest.flatten(cached_state))
    ]
    return nest.pack_sequence_as(cached_state, merged)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:30,代码来源:state_management.py


示例7: _apply_exogenous_update

 def _apply_exogenous_update(
     self, current_times, step_number, state, raw_features,
     embedded_exogenous_regressors):
   """Performs a conditional state update based on exogenous features.

   Args:
     current_times: Times of the current step; passed to
       `self._exogenous_input_step` and to the update condition.
     step_number: Index of the current step. Selects `[:, step_number, :]`
       from `embedded_exogenous_regressors` and `[:, step_number]` from each
       raw feature.
     state: Current (possibly nested) model state.
     raw_features: Mapping of raw features. The state tuple, times and values
       entries are excluded from what the update condition sees.
     embedded_exogenous_regressors: Embedded exogenous regressors, or None.

   Returns:
     `state` unchanged if there are no exogenous regressors. Otherwise, the
     result of `self._exogenous_input_step`. When
     `self._exogenous_update_condition` is set, each element is selected
     element-wise between the updated and the original state by that
     condition.
   """
   if embedded_exogenous_regressors is None:
     return state
   current_exogenous_regressors = embedded_exogenous_regressors[
       :, step_number, :]
   exogenous_updated_state = self._exogenous_input_step(
       current_times=current_times,
       current_exogenous_regressors=current_exogenous_regressors,
       state=state)
   if self._exogenous_update_condition is None:
     return exogenous_updated_state
   current_raw_exogenous_features = {
       key: value[:, step_number] for key, value in raw_features.items()
       if key not in [PredictionFeatures.STATE_TUPLE,
                      TrainEvalFeatures.TIMES,
                      TrainEvalFeatures.VALUES]}
   # The condition depends only on the times and features, not on the state
   # element. Evaluate it once instead of re-running the condition function
   # for every flattened state element.
   update_condition = self._exogenous_update_condition(
       times=current_times,
       features=current_raw_exogenous_features)
   conditionally_updated_state_flat = [
       array_ops.where(update_condition, updated_element, original_element)
       for updated_element, original_element in zip(
           nest.flatten(exogenous_updated_state), nest.flatten(state))]
   return nest.pack_sequence_as(state, conditionally_updated_state_flat)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:33,代码来源:model.py


示例8: _Update

def _Update(struct_acc, struct_x, t):
  """Updates t-th row in accumulators.

  Args:
    struct_acc: The accumulators. A structure of tensors.
    struct_x: The new values. A structure of tensors congruent to `struct_acc`.
    t: A scalar integer. Performance is better if `t` is on the device
      memory.

  Returns:
    A structure of tensors. Say, ret is a returned dictionary. Then, for
    each key, we have:
      ret[key] = struct_acc[key];
      ret[key][t, :] = struct_x[key]
  """
  acc_lst = nest.flatten(struct_acc)
  x_lst = nest.flatten(struct_x)
  t = math_ops.to_int32([t])  # tf.to_int32 casts on-device tensors.
  # TODO(b/62105730): inplace updates of rank-1 tensors may need to be
  # skipped (passing the accumulator through unchanged) until that bug is
  # fixed. Nothing is skipped at the moment. Checking membership in an empty
  # skip set would still hash every tensor, which fails once tensors are
  # unhashable (TF2 tensor equality).
  updated = [
      alias_inplace_update(acc, t, array_ops.expand_dims(x, 0))
      for acc, x in zip(acc_lst, x_lst)
  ]
  return nest.pack_sequence_as(struct_acc, updated)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:29,代码来源:recurrent.py


示例9: testFlattenAndPack

  def testFlattenAndPack(self):
    """`flatten` and `pack_sequence_as` round-trip nested structures."""
    nested = ((3, 4), 5, (6, 7, (9, 10), 8))
    self.assertEqual(nest.flatten(nested), [3, 4, 5, 6, 7, 9, 10, 8])
    letters = ["a", "b", "c", "d", "e", "f", "g", "h"]
    expected_packed = (("a", "b"), "c", ("d", "e", ("f", "g"), "h"))
    self.assertEqual(nest.pack_sequence_as(nested, letters), expected_packed)

    # Namedtuples flatten field by field and are rebuilt with their type.
    Point = collections.namedtuple("Point", ["x", "y"])
    with_points = (Point(x=4, y=2), ((Point(x=1, y=0),),))
    flat_coords = [4, 2, 1, 0]
    self.assertEqual(nest.flatten(with_points), flat_coords)
    rebuilt = nest.pack_sequence_as(with_points, flat_coords)
    self.assertEqual(rebuilt, with_points)
    outer_point = rebuilt[0]
    inner_point = rebuilt[1][0][0]
    self.assertEqual(outer_point.x, 4)
    self.assertEqual(outer_point.y, 2)
    self.assertEqual(inner_point.x, 1)
    self.assertEqual(inner_point.y, 0)

    # Non-sequences (ints, NumPy arrays, strings as structure) are leaves.
    self.assertEqual([5], nest.flatten(5))
    self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
    self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
    self.assertEqual(
        np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

    # Mismatched structure / flat sequence combinations must raise.
    # NOTE: assertRaisesRegexp is the Python 2 spelling; Python 3 deprecates
    # it in favour of assertRaisesRegex and removes it in 3.12.
    with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
      nest.pack_sequence_as("scalar", [4, 5])
    with self.assertRaisesRegexp(TypeError, "flat_sequence"):
      nest.pack_sequence_as([4, 5], "bad_sequence")
    with self.assertRaises(ValueError):
      nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:33,代码来源:nest_test.py


示例10: _hierarchical_pad

    def _hierarchical_pad(input_, output, control):
      """Pad and flatten hierarchical inputs, outputs, and controls.

      Every level except the innermost is padded with a one-hot pad element
      up to `self._max_lengths[:-1]` and flattened into a list of segments.
      Each segment is then zero-padded to `self._max_lengths[-1]`, and the
      segments are concatenated. `control` is processed only when non-empty.

      Returns:
        `(input_, output, control, length)`, where `length` is the squeezed
        int32 array of segment lengths of the flattened, padded input.
      """
      def _pad_outer_levels(seq, pad_token, depth):
        # Pad with an end/pad token and flatten the hierarchy into segments.
        return nest.flatten(pad_with_element(
            seq, self._max_lengths[:-1], data.np_onehot([pad_token], depth)))

      def _pad_and_concat(segments):
        # Zero-pad every segment to the innermost length and join them.
        return np.concatenate(
            [pad_with_value(seg, self._max_lengths[-1], 0) for seg in segments])

      flat_input = _pad_outer_levels(input_, self.end_token, self.input_depth)
      flat_output = _pad_outer_levels(output, self.end_token, self.output_depth)
      length = np.squeeze(np.array([len(seg) for seg in flat_input], np.int32))

      padded_input = _pad_and_concat(flat_input)
      padded_output = _pad_and_concat(flat_output)

      if np.size(control):
        control = _pad_and_concat(_pad_outer_levels(
            control, self._control_pad_token, self.control_depth))

      return padded_input, padded_output, control, length
开发者ID:cghawthorne,项目名称:magenta,代码行数:26,代码来源:data_hierarchical.py


示例11: testNoProjNoShardingNestedTupleStateSaver

  def testNoProjNoShardingNestedTupleStateSaver(self):
    """A 4-layer LSTM's nested tuple state round-trips through a state saver.

    Layer i is an LSTMCell with `num_units + i` units and state names
    ("c<i>", "m<i>"). After `static_state_saving_rnn`, the flattened final
    state must equal the states recorded by the TestStateSaver, matched by
    name.
    """
    num_units = 3
    input_size = 5
    batch_size = 2
    max_length = 8
    num_layers = 4
    with self.test_session(graph=tf.Graph()) as sess:
      initializer = tf.random_uniform_initializer(-0.01, 0.01, seed=self._seed)
      # Saved state "c<i>"/"m<i>" has size num_units + i (c0, m0, c1, m1, ...).
      state_saver = TestStateSaver(
          batch_size,
          {"%s%d" % (prefix, layer): num_units + layer
           for layer in range(num_layers) for prefix in ("c", "m")})

      def _cell(i):
        return tf.contrib.rnn.LSTMCell(
            num_units + i, use_peepholes=False, initializer=initializer,
            state_is_tuple=True)

      # The state is a tuple of `num_layers` sub-tuples of length 2 each.
      cell = tf.contrib.rnn.MultiRNNCell(
          [_cell(i) for i in range(num_layers)], state_is_tuple=True)

      self.assertEqual(len(cell.state_size), num_layers)
      for layer_state_size in cell.state_size:
        self.assertEqual(len(layer_state_size), 2)

      # Every time step is fed from the same placeholder.
      step_input = tf.placeholder(tf.float32, shape=(batch_size, input_size))
      inputs = [step_input] * max_length

      state_names = tuple(("c%d" % layer, "m%d" % layer)
                          for layer in range(num_layers))
      with tf.variable_scope("share_scope"):
        outputs, state = tf.contrib.rnn.static_state_saving_rnn(
            cell, inputs, state_saver=state_saver, state_name=state_names)
      self.assertEqual(len(outputs), len(inputs))

      # Outputs come from the last layer, _cell(num_layers - 1).
      last_layer_units = num_units + num_layers - 1
      for out in outputs:
        self.assertEqual(out.get_shape().as_list(),
                         [batch_size, last_layer_units])

      tf.global_variables_initializer().run()
      input_value = np.random.randn(batch_size, input_size)
      last_states = sess.run(
          list(nest.flatten(state)), feed_dict={step_input: input_value})
      saved_states = sess.run(
          list(state_saver.saved_state.values()),
          feed_dict={step_input: input_value})
      num_states = 2 * num_layers
      self.assertEqual(num_states, len(last_states))
      self.assertEqual(num_states, len(saved_states))
      named_saved_states = dict(
          zip(state_saver.saved_state.keys(), saved_states))

      for state_name, last_state in zip(nest.flatten(state_names),
                                        last_states):
        self.assertAllEqual(last_state, named_saved_states[state_name])
开发者ID:Hwhitetooth,项目名称:tensorflow,代码行数:59,代码来源:core_rnn_test.py


示例12: body

        def body(time, elements_finished, current_input, emit_ta, state, loop_state):
            """Internal while loop body for raw_rnn.

            Runs one step of `cell` and asks `loop_fn` for the next finished
            flags, input, state, emit and loop state. For batch entries that
            were already finished before this step, it keeps the previous
            state and `zero_emit` instead. It then writes this step's emit
            into `emit_ta` at index `time`.

            Args:
              time: time scalar.
              elements_finished: batch-size vector.
              current_input: possibly nested tuple of input tensors.
              emit_ta: possibly nested tuple of output TensorArrays.
              state: possibly nested tuple of state tensors.
              loop_state: possibly nested tuple of loop state tensors.

            Returns:
              Tuple having the same size as Args but with updated values.
            """
            # One step of the RNN cell.
            (next_output, cell_state) = cell(current_input, state)

            # The cell must preserve the state structure and emit outputs
            # structured like its `output_size`.
            nest.assert_same_structure(state, cell_state)
            nest.assert_same_structure(cell.output_size, next_output)

            # Let the user-supplied `loop_fn` decide the next step's values.
            next_time = time + 1
            (next_finished, next_input, next_state, emit_output, next_loop_state) = loop_fn(
                next_time, next_output, cell_state, loop_state
            )

            nest.assert_same_structure(state, next_state)
            nest.assert_same_structure(current_input, next_input)
            nest.assert_same_structure(emit_ta, emit_output)

            # If loop_fn returns None for next_loop_state, just reuse the
            # previous one.
            loop_state = loop_state if next_loop_state is None else next_loop_state

            def _copy_some_through(current, candidate):
                """Select `current` for finished batch entries, else `candidate`.

                Both structures are flattened and combined leaf by leaf with
                `array_ops.where(elements_finished, ...)` on the device of the
                candidate's op, then repacked like `current`.
                """
                current_flat = nest.flatten(current)
                candidate_flat = nest.flatten(candidate)
                # NOTE(review): the lambda closes over the loop variables;
                # this is only safe if `_on_device` calls it immediately
                # (presumably it does; confirm in its definition).
                # pylint: disable=g-long-lambda,cell-var-from-loop
                result_flat = [
                    _on_device(
                        lambda: array_ops.where(elements_finished, current_i, candidate_i), device=candidate_i.op.device
                    )
                    for (current_i, candidate_i) in zip(current_flat, candidate_flat)
                ]
                # pylint: enable=g-long-lambda,cell-var-from-loop
                return nest.pack_sequence_as(structure=current, flat_sequence=result_flat)

            # Finished entries emit `zero_emit` and keep their previous state.
            emit_output = _copy_some_through(zero_emit, emit_output)
            next_state = _copy_some_through(state, next_state)

            emit_output_flat = nest.flatten(emit_output)
            emit_ta_flat = nest.flatten(emit_ta)

            # Must come after the `_copy_some_through` calls above: they read
            # `elements_finished` from this scope at call time and need the
            # value from before this step.
            elements_finished = math_ops.logical_or(elements_finished, next_finished)

            # Record this step's emit at index `time` in each TensorArray.
            emit_ta_flat = [ta.write(time, emit) for (ta, emit) in zip(emit_ta_flat, emit_output_flat)]

            emit_ta = nest.pack_sequence_as(structure=emit_structure, flat_sequence=emit_ta_flat)

            return (next_time, elements_finished, next_input, emit_ta, next_state, loop_state)
开发者ID:ygoverdhan,项目名称:tensorflow,代码行数:59,代码来源:rnn.py


示例13: _run_targets

 def _run_targets(self, targets1, targets2=None, run_init=True):
   """Evaluate `targets1` then `targets2`, optionally initializing variables.

   Both target structures are flattened. `targets2` must be None, empty, or
   flatten to as many elements as `targets1` (checked with `assert`).

   Returns:
     The result of `self.evaluate` on the concatenated flat target list.
   """
   first = nest.flatten(targets1)
   second = [] if targets2 is None else nest.flatten(targets2)
   assert not second or len(second) == len(first)
   if run_init:
     self.evaluate(variables.global_variables_initializer())
   return self.evaluate(first + second)
开发者ID:LongJun123456,项目名称:tensorflow,代码行数:8,代码来源:control_flow_ops_test.py


示例14: testFlattenDictOrder

 def testFlattenDictOrder(self):
   """`flatten` orders dicts by key, including OrderedDicts."""
   # Same entries, deliberately inserted out of key order.
   pairs = [("d", 3), ("b", 1), ("a", 0), ("c", 2)]
   expected = [0, 1, 2, 3]
   for mapping in (collections.OrderedDict(pairs), dict(pairs)):
     self.assertEqual(expected, nest.flatten(mapping))
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:8,代码来源:nest_test.py


示例15: _assert_same_shape

 def _assert_same_shape(input1, input2, double=False):
   """Assert that two (nested) tensor structures have matching static shapes.

   `self` is not a parameter. It must resolve from the enclosing scope
   (presumably the surrounding test method's `self`).

   Args:
     input1: Tensor or nested structure of tensors.
     input2: Tensor or nested structure with the same number of leaves.
     double: If True, dimension 1 of each `input1` shape is doubled before
       being compared with the corresponding `input2` shape.
   """
   flat_input1 = nest.flatten(input1)
   flat_input2 = nest.flatten(input2)
   # zip() stops at the shorter list, so compare leaf counts first; otherwise
   # extra or missing tensors would go unnoticed.
   self.assertEqual(len(flat_input1), len(flat_input2))
   for inp1, inp2 in zip(flat_input1, flat_input2):
     input_shape = inp1.get_shape().as_list()
     if double:
       input_shape[1] *= 2
     self.assertEqual(input_shape, inp2.get_shape().as_list())
开发者ID:Hwhitetooth,项目名称:tensorflow,代码行数:8,代码来源:core_rnn_test.py


示例16: state_barrier_context

def state_barrier_context(state):
  """Return a context manager that prevents interior ops from running
  unless the whole state has been computed.

  This is to prevent assign race conditions. The control dependencies are
  every `tf.Tensor` in the flattened `state`, plus the `.flow` of every
  element that has one (presumably TensorArrays).

  Args:
    state: A tensor-like object or nested structure of them.

  Returns:
    The `tf.control_dependencies` context manager over those dependencies.
  """
  flat_state = nest.flatten(state)
  # isinstance, not an exact type match: newer TF versions represent graph
  # tensors with subclasses of tf.Tensor, which `type(x) == tf.Tensor` would
  # silently drop from the barrier.
  tensors = [x for x in flat_state if isinstance(x, tf.Tensor)]
  tarray = [x.flow for x in flat_state if hasattr(x, "flow")]
  return tf.control_dependencies(tensors + tarray)
开发者ID:ALISCIFP,项目名称:models,代码行数:9,代码来源:utils.py


示例17: wrap_state

 def wrap_state(self, state):
     """Wrap a cell state into the beam decoder's state structure.

     Batch size (`tf.shape(...)[0]`) and dtype are read from `state`. When
     `state` is a nested sequence, they come from its first flattened
     element. Both are passed, with `state` as `cell_state`, to
     `_create_state` of a `BeamDecoderCellWrapper` built with no cell and this
     object's num_classes, max_len, stop_token and beam_size.
     """
     helper = BeamDecoderCellWrapper(None, self.num_classes, self.max_len, self.stop_token, self.beam_size)
     if nest.is_sequence(state):
         reference = nest.flatten(state)[0]
     else:
         reference = state
     return helper._create_state(tf.shape(reference)[0], reference.dtype, cell_state=state)
开发者ID:anair13,项目名称:tensorflow,代码行数:9,代码来源:beam_decoder.py


示例18: _get_grads_lists_curvature_prop

 def _get_grads_lists_curvature_prop(self, tensors):
   """Compute per-tensor gradient lists for curvature propagation.

   Differentiates the inputs of every loss in `self._layers.losses` with
   respect to `tensors`, using `self._get_transformed_random_signs()` as
   `grad_ys`.

   Returns:
     A tuple holding a one-element tuple `(grad,)` for each top-level element
     of the gradients packed like `tensors`.
   """
   loss_inputs = [loss.inputs for loss in self._layers.losses]
   transformed_random_signs = self._get_transformed_random_signs()
   flat_grads = gradients_impl.gradients(
       nest.flatten(loss_inputs),
       nest.flatten(tensors),
       grad_ys=nest.flatten(transformed_random_signs))
   packed_grads = nest.pack_sequence_as(tensors, flat_grads)
   return tuple([(g,) for g in packed_grads])
开发者ID:SylChan,项目名称:tensorflow,代码行数:9,代码来源:estimator.py


示例19: run_steps_on_dataset

  def run_steps_on_dataset(self, fn, iterator, iterations):
    """Run `fn` on `iterations` batches from `iterator` on the TPU.

    A host-side while loop enqueues `iterations` rounds of inputs through TPU
    infeed, one shard per core (`self._num_cores_per_host` shards per round).
    The TPU computation dequeues a batch, repacks it into the iterator's
    nested structure and calls `fn` on it, wrapped in
    `tpu.repeat(iterations, ...)` and `tpu.batch_parallel`.

    Args:
      fn: Callable applied to each dequeued batch (structured like
        `iterator.output_shapes`).
      iterator: Dataset iterator providing `output_shapes`, `output_types`
        and `get_next()`.
      iterations: Number of steps to run.

    Returns:
      An op grouping the TPU computation result and the enqueue loop.

    Raises:
      ValueError: If any of the iterator's output shapes is not fully defined.
    """
    # Enqueue ops
    shapes = nest.flatten(iterator.output_shapes)
    # TPU infeed needs static shapes, so reject partially-defined ones early.
    if any([not s.is_fully_defined() for s in shapes]):
      raise ValueError(
          'TPU currently requires fully defined shapes. Either use '
          'set_shape() on the input tensors or use '
          'dataset.apply(map_and_batch(..., drop_remainder=True)).')
    types = nest.flatten(iterator.output_types)

    def enqueue_ops_fn():
      """Build the enqueue ops for one iteration: one input shard per core."""
      control_deps = []
      sharded_inputs = []
      with ops.device(self._host):
        for _ in range(self._num_cores_per_host):
          # Use control dependencies to ensure a deterministic ordering.
          with ops.control_dependencies(control_deps):
            inputs = nest.flatten(iterator.get_next())
            control_deps.extend(inputs)
            sharded_inputs.append(inputs)

      # Shard i is enqueued to device ordinal i.
      enqueue_ops = []
      for core_id, shard_input in enumerate(sharded_inputs):
        enqueue_ops.append(
            tpu_ops.infeed_enqueue_tuple(
                inputs=shard_input, shapes=shapes, device_ordinal=core_id))
      return enqueue_ops

    def enqueue_ops_loop_body(i):
      # Tie the counter increment to this iteration's enqueue ops.
      with ops.control_dependencies(enqueue_ops_fn()):
        return i + 1

    # Host-side loop enqueueing `iterations` rounds; parallel_iterations=1
    # prevents iterations from running in parallel. This `enqueue_ops` (the
    # loop op) is distinct from the local list inside `enqueue_ops_fn`.
    with ops.device(self._host):
      enqueue_ops = control_flow_ops.while_loop(
          lambda i: i < iterations,
          enqueue_ops_loop_body,
          [constant_op.constant(0)],
          parallel_iterations=1)

    # Dequeue ops
    def dequeue_fn():
      # Dequeue one batch and restore the iterator's nested structure.
      dequeued = tpu.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
      return nest.pack_sequence_as(iterator.output_shapes, dequeued)

    # Wrap `fn` for repeat.
    run_fn = lambda: fn(dequeue_fn())

    # Repeat
    def iterate_on_tpu():
      return tpu.repeat(iterations, run_fn, [])

    # Re-write and distribute computation.
    tpu_result = tpu.batch_parallel(
        iterate_on_tpu, [], num_shards=self._num_cores_per_host)

    return control_flow_ops.group(tpu_result, enqueue_ops)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:57,代码来源:tpu_strategy.py


示例20: _check_default_value

def _check_default_value(shape, default_value, dtype, key):
  """Returns default value as tuple if it's valid, otherwise raises errors.

  This function verifies that `default_value` is compatible with both `shape`
  and `dtype`. If it is not compatible, it raises an error. If it is compatible,
  it casts default_value to a tuple and returns it. `key` is used only
  for error message.

  Args:
    shape: An iterable of integers specifies the shape of the `Tensor`.
    default_value: If a single value is provided, the same value will be applied
      as the default value for every item. If an iterable of values is
      provided, the shape of the `default_value` should be equal to the given
      `shape`. NumPy arrays and NumPy scalars are first converted with
      `tolist()`.
    dtype: defines the type of values. Default value is `tf.float32`. Must be a
      non-quantized, real integer or floating point type.
    key: A string providing key to look up corresponding `Tensor`.

  Returns:
    A tuple which will be used as default value, or None if `default_value`
    is None.

  Raises:
    ValueError: if `default_value` is an iterable whose shape is not equal to
      `shape`.
    TypeError: if `default_value` is not compatible with `dtype`.
  """
  if default_value is None:
    return None

  # Convert NumPy arrays and scalars to plain Python lists/scalars up front,
  # so NumPy scalars (e.g. np.int64, np.float32) are treated like Python ints
  # and floats below instead of falling through to the TypeError.
  if callable(getattr(default_value, 'tolist', None)):
    default_value = default_value.tolist()

  # Integers are accepted for any dtype. bool is a subclass of int, so
  # True/False are accepted here too.
  if isinstance(default_value, int):
    return _create_tuple(shape, default_value)

  if isinstance(default_value, float) and dtype.is_floating:
    return _create_tuple(shape, default_value)

  if nest.is_sequence(default_value):
    if not _is_shape_and_default_value_compatible(default_value, shape):
      raise ValueError(
          'The shape of default_value must be equal to given shape. '
          'default_value: {}, shape: {}, key: {}'.format(
              default_value, shape, key))
    # Check if the values in the list are all integers or are convertible to
    # floats.
    flat_values = nest.flatten(default_value)
    is_list_all_int = all(isinstance(v, int) for v in flat_values)
    is_list_has_float = any(isinstance(v, float) for v in flat_values)
    if is_list_all_int:
      return _as_tuple(default_value)
    if is_list_has_float and dtype.is_floating:
      return _as_tuple(default_value)
  raise TypeError('default_value must be compatible with dtype. '
                  'default_value: {}, dtype: {}, key: {}'.format(
                      default_value, dtype, key))
开发者ID:finardi,项目名称:tensorflow,代码行数:57,代码来源:feature_column.py



注：本文中的 tensorflow.python.util.nest.flatten 函数示例由纯净天空整理自 GitHub/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python nest.is_sequence 函数代码示例（发布时间：2022-05-27）
下一篇:
Python nest.assert_same_structure 函数代码示例（发布时间：2022-05-27）
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap