• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python data_flow_ops.dynamic_stitch函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.ops.data_flow_ops.dynamic_stitch函数的典型用法代码示例。如果您正苦于以下问题:Python dynamic_stitch函数的具体用法?Python dynamic_stitch怎么用?Python dynamic_stitch使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了dynamic_stitch函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: testErrorIndicesMultiDimensional

 def testErrorIndicesMultiDimensional(self):
   """Inconsistent multi-dimensional index/data shapes must raise ValueError.

   The second index tensor is 2-D ([1, 5]) while its data tensor is 1-D
   ([5]), and the first data tensor ([1, 3]) does not start with the shape of
   its index tensor ([3]).
   """
   first_indices = constant_op.constant([0, 4, 7])
   second_indices = constant_op.constant([[1, 6, 2, 3, 5]])
   first_data = constant_op.constant([[0, 40, 70]])
   second_data = constant_op.constant([10, 60, 20, 30, 50])
   with self.assertRaises(ValueError):
     data_flow_ops.dynamic_stitch([first_indices, second_indices],
                                  [first_data, second_data])
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:10,代码来源:dynamic_stitch_op_test.py


示例2: testErrorDataDimSizeMismatch

 def testErrorDataDimSizeMismatch(self):
   """Data tensors whose trailing dimensions differ must raise ValueError.

   The first data tensor has rows of width 1, the second rows of width 2.
   """
   index_tensors = [
       constant_op.constant(positions)
       for positions in ([0, 4, 5], [1, 6, 2, 3])
   ]
   data_tensors = [
       constant_op.constant(rows)
       for rows in ([[0], [40], [70]],
                    [[10, 11], [60, 61], [20, 21], [30, 31]])
   ]
   with self.assertRaises(ValueError):
     data_flow_ops.dynamic_stitch(index_tensors, data_tensors)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:10,代码来源:dynamic_stitch_op_test.py


示例3: testErrorDataAndIndicesSizeMismatch

 def testErrorDataAndIndicesSizeMismatch(self):
   """An index tensor and its data tensor of different lengths must fail.

   The second pair has 5 indices but only 4 data elements.
   """
   index_tensors = [
       constant_op.constant(positions)
       for positions in ([0, 4, 7], [1, 6, 2, 3, 5])
   ]
   data_tensors = [
       constant_op.constant(values)
       for values in ([0, 40, 70], [10, 60, 20, 30])
   ]
   with self.assertRaises(ValueError):
     data_flow_ops.dynamic_stitch(index_tensors, data_tensors)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:10,代码来源:dynamic_stitch_op_test.py


示例4: testHigherRankGPU

 def testHigherRankGPU(self):
   """Stitches scalar, vector and matrix indices; checks values and grads."""
   with self.cached_session() as sess:
     indices = [
         constant_op.constant(6),
         constant_op.constant([4, 1]),
         constant_op.constant([[5, 2], [0, 3]]),
     ]
     data = [
         constant_op.constant([61, 62], dtype=dtypes.float32),
         constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32),
         constant_op.constant(
             [[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32),
     ]
     stitched_t = data_flow_ops.dynamic_stitch(indices, data)
     stitched_val = self.evaluate(stitched_t)
     # Row k of the stitched result is [10 * k + 1, 10 * k + 2], k = 0..6.
     expected = 10 * np.arange(7)[:, None] + [1.0, 2.0]
     self.assertAllEqual(expected, stitched_val)
     self.assertEqual([7, 2], stitched_t.get_shape().as_list())
     # Backpropagate an upstream gradient of 7 * output: every data input is
     # expected to receive 7 * itself, and the integer indices get nothing.
     upstream = 7 * stitched_val
     grads = gradients_impl.gradients(stitched_t, indices + data, upstream)
     self.assertEqual(grads[:3], [None] * 3)
     data_grads = sess.run(grads[3:])
     for datum, grad in zip(data, data_grads):
       self.assertAllEqual(7.0 * self.evaluate(datum), grad)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:25,代码来源:dynamic_stitch_op_test.py


示例5: lookup

  def lookup(self, keys, name=None):
    """Looks up `keys` across the table shards and stitches the results.

    Keys are routed to shards with `self._shard_indices`, each shard performs
    its own lookup, and the per-shard values are reassembled into the original
    key order with `dynamic_stitch`.

    Args:
      keys: Tensor of keys with dtype `self._key_dtype`. Only vectors are
        supported (see the TODO below).
      name: Optional name passed to every shard's `lookup`.

    Returns:
      Tensor of looked-up values whose leading dimension matches `keys`,
      followed by `self._value_shape`.

    Raises:
      TypeError: if `keys.dtype` is not the table's key dtype.
    """
    # Local import so the dynamic key count can be read without relying on a
    # module-level import that may not exist in this file.
    from tensorflow.python.ops import array_ops  # pylint: disable=g-import-not-at-top

    if keys.dtype != self._key_dtype:
      raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' %
                      (self._key_dtype, keys.dtype))
    self._check_keys(keys)
    num_shards = self._num_shards
    if num_shards == 1:
      return self._table_shards[0].lookup(keys, name=name)

    shard_indices = self._shard_indices(keys)
    # TODO(andreasst): support 'keys' that are not vectors
    key_shards = data_flow_ops.dynamic_partition(keys, shard_indices,
                                                 num_shards)
    value_shards = [
        self._table_shards[i].lookup(key_shards[i], name=name)
        for i in range(num_shards)
    ]

    # Static length of `keys`; may be an unknown Dimension, which cannot be
    # converted to a tensor, so it is only used for shape inference below.
    static_num_keys = keys.get_shape().dims[0]
    # Use the runtime length to enumerate key positions, so lookups also work
    # when the number of keys is not known at graph-construction time.
    num_keys = array_ops.shape(keys)[0]
    original_indices = math_ops.range(num_keys)
    partitioned_indices = data_flow_ops.dynamic_partition(original_indices,
                                                          shard_indices,
                                                          num_shards)
    # Put every shard's values back at the positions its keys came from.
    result = data_flow_ops.dynamic_stitch(partitioned_indices, value_shards)
    result.set_shape(
        tensor_shape.TensorShape([static_num_keys]).concatenate(
            self._value_shape))
    return result
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:27,代码来源:sharded_mutable_dense_hashtable.py


示例6: DynamicStitchGrads

    def DynamicStitchGrads(op, grad):
        """Gradient of DynamicStitch that accounts for overwritten positions.

        An output position written by several inputs only keeps the value of
        the last writer, so an earlier input must not receive gradient there.
        For input ``i`` the gradient is therefore re-stitched with zeros at
        every position claimed by a later input ``j > i`` before its rows are
        gathered.

        Args:
            op: the DynamicStitch op being differentiated; its first
                ``num_values`` inputs are the index tensors, followed by the
                ``num_values`` data tensors.
            grad: gradient w.r.t. the op's output (dense or IndexedSlices).

        Returns:
            A list with ``None`` for each index input followed by one gradient
            tensor per data input.
        """
        num_values = len(op.inputs) // 2
        # Integer index inputs are not differentiable.
        indices_grad = [None] * num_values

        def AsInt32(x):
            return (x if op.inputs[0].dtype == dtypes.int32 else
                    math_ops.cast(x, dtypes.int32))

        # Flattened index lists, one per input.
        idxs = [AsInt32(array_ops.reshape(op.inputs[i], (-1,)))
                for i in range(num_values)]
        if isinstance(grad, ops.IndexedSlices):
            # Densify a sparse gradient so it can be stitched row-wise.
            output_shape = array_ops.shape(op.outputs[0])
            output_rows = output_shape[0]
            grad = math_ops.unsorted_segment_sum(grad.values, grad.indices,
                                                 output_rows)

        values_grad = []
        # One zero row per flattened index of each input.  Deriving the shape
        # from gather(grad, x) guarantees exactly len(x) rows; slicing
        # zeros_like(grad)[:len(x)] is silently truncated to the number of
        # output rows when an input carries duplicate indices, which made the
        # dynamic_stitch below fail with a shape mismatch.
        idx_zeros = [array_ops.zeros_like(array_ops.gather(grad, x))
                     for x in idxs]
        grad_range = math_ops.range(array_ops.shape(grad)[0])
        for i in range(num_values):
            if i == num_values - 1:
                # No later input can overwrite the last one.
                v_grad = grad
            else:
                # Zero out every output row claimed by a later input.
                v_grad = data_flow_ops.dynamic_stitch(
                    [grad_range] + idxs[i + 1:], [grad] + idx_zeros[i + 1:])
            v_grad = array_ops.gather(v_grad, AsInt32(op.inputs[i]))
            values_grad += [v_grad]

        return indices_grad + values_grad
开发者ID:nengo,项目名称:nengo_deeplearning,代码行数:30,代码来源:tensorflow_patch.py


示例7: testPinRequiredOpsOnCPU

 def testPinRequiredOpsOnCPU(self):
     """Checks device placement under graph_util.pin_variables_on_cpu.

     Variables and assignments to them must be placed on the CPU, while the
     plain (non-variable) ops built here must be left without a device.
     """
     graph = ops.Graph()
     with graph.as_default() as g, g.device(graph_util.pin_variables_on_cpu):
         const_a = constant_op.constant(5.0)
         const_b = constant_op.constant(10.0)
         add_c = const_a + const_b
         var_v = state_ops.variable_op([], dtype=types.float32)
         assign_c_to_v = state_ops.assign(var_v, add_c)
         # Int and float dynamic_stitch ops are built as well; their placement
         # is not asserted by this test.
         data_flow_ops.dynamic_stitch([[0, 1, 2], [2, 3]],
                                      [[12, 23, 34], [1, 2]])
         data_flow_ops.dynamic_stitch([[0, 1, 2], [2, 3]],
                                      [[12.0, 23.0, 34.0], [1.0, 2.0]])
         # Non-variable ops should not specify a device.
         for non_variable_op in (const_a, const_b, add_c):
             self.assertEqual(non_variable_op.device, None)
         # Variable ops specify a device.
         for variable_op in (var_v, assign_c_to_v):
             self.assertEqual(variable_op.device, "/device:CPU:0")
开发者ID:sumodm,项目名称:tensorflow,代码行数:18,代码来源:graph_util_test.py


示例8: testScalarGPU

 def testScalarGPU(self):
   """Two scalar indices, in either order, stitch into a float 2-vector."""
   indices = [constant_op.constant(0), constant_op.constant(1)]
   data = [constant_op.constant(40.0), constant_op.constant(60.0)]
   for step in (-1, 1):
     ordered_indices = indices[::step]
     stitched_t = data_flow_ops.dynamic_stitch(ordered_indices, data)
     expected = [40.0, 60.0][::step]
     self.assertAllEqual(expected, self.evaluate(stitched_t))
     # Dimension 0 is max(flatten(indices))+1.
     self.assertEqual([2], stitched_t.get_shape().as_list())
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:9,代码来源:dynamic_stitch_op_test.py


示例9: testSumGradArgs

 def testSumGradArgs(self):
   """Positions claimed by a later index list take that list's values."""
   with self.test_session(use_gpu=False):
     indices = [
         ops.convert_to_tensor(positions)
         for positions in ([0, 1, 2, 3], [2, 3])
     ]
     values = [
         ops.convert_to_tensor(vals) for vals in ([2, 3, 5, 7], [1, 1])
     ]
     stitched = data_flow_ops.dynamic_stitch(indices, values)
     self.assertAllEqual(stitched.eval(), [2, 3, 1, 1])
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:10,代码来源:embedding_ops_test.py


示例10: testInt32Gpu

 def testInt32Gpu(self):
   """Int32 stitching on GPU: overlapping position 2 takes the later value."""
   with self.test_session(use_gpu=True):
     indices = [
         ops.convert_to_tensor(positions) for positions in ([0, 1, 2], [2, 3])
     ]
     values = [
         ops.convert_to_tensor(vals) for vals in ([12, 23, 34], [1, 2])
     ]
     stitched = data_flow_ops.dynamic_stitch(indices, values)
     self.assertAllEqual(stitched.eval(), [12, 23, 1, 2])
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:10,代码来源:embedding_ops_test.py


示例11: testPinToCpu

 def testPinToCpu(self):
   """Each op built here under graph_util.pin_to_cpu is placed on the CPU."""
   with ops.Graph().as_default() as g, g.device(graph_util.pin_to_cpu):
     const_a = constant_op.constant(5.0)
     const_b = constant_op.constant(10.0)
     add_c = const_a + const_b
     var_v = state_ops.variable_op([], dtype=dtypes.float32)
     assign_c_to_v = state_ops.assign(var_v, add_c)
     const_string = constant_op.constant("on a cpu")
     stitch_indices = [[0, 1, 2], [2, 3]]
     dynamic_stitch_int_result = data_flow_ops.dynamic_stitch(
         stitch_indices, [[12, 23, 34], [1, 2]])
     dynamic_stitch_float_result = data_flow_ops.dynamic_stitch(
         stitch_indices, [[12.0, 23.0, 34.0], [1.0, 2.0]])
   placed_ops = (const_a, const_b, add_c, var_v, assign_c_to_v, const_string,
                 dynamic_stitch_int_result, dynamic_stitch_float_result)
   for placed in placed_ops:
     self.assertDeviceEqual(placed.device, "/device:CPU:0")
开发者ID:manipopopo,项目名称:tensorflow,代码行数:20,代码来源:graph_util_test.py


示例12: testStitchOrder

 def testStitchOrder(self):
   """With ten identical index lists, the last data tensor wins everywhere."""
   with self.test_session():
     index_tensors = []
     np_values = []
     value_tensors = []
     for _ in range(10):
       index_tensors.append(
           ops.convert_to_tensor(np.arange(100).astype(np.int32)))
       np_values.append(np.random.uniform(size=100))
       value_tensors.append(ops.convert_to_tensor(np_values[-1]))
     stitched = data_flow_ops.dynamic_stitch(index_tensors,
                                             value_tensors).eval()
   self.assertAllEqual(np_values[-1], stitched)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:11,代码来源:embedding_ops_test.py


示例13: testOneListOneDimensional

 def testOneListOneDimensional(self):
   """A single permuted index list places each value at its index."""
   with self.test_session():
     permutation = constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])
     values = constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])
     stitched_t = data_flow_ops.dynamic_stitch([permutation], [values])
     self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_t.eval())
     # The output length depends on the runtime maximum index, so statically
     # only "a vector of unknown length" can be inferred.
     self.assertEqual([None], stitched_t.get_shape().as_list())
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:11,代码来源:dynamic_stitch_op_test.py


示例14: testScalar

 def testScalar(self):
   """Two scalar indices, in either order, stitch into an int 2-vector."""
   with self.test_session():
     indices = [constant_op.constant(0), constant_op.constant(1)]
     data = [constant_op.constant(40), constant_op.constant(60)]
     for step in (-1, 1):
       stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
       self.assertAllEqual([40, 60][::step], stitched_t.eval())
       # The output length depends on the runtime maximum index, so
       # statically only "a vector of unknown length" can be inferred.
       self.assertEqual([None], stitched_t.get_shape().as_list())
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:12,代码来源:dynamic_stitch_op_test.py


示例15: _ReductionGradAssist

def _ReductionGradAssist(op):
    """Reduction grads have much in common, so factor the commonality out.

    Builds `new_output_shape`: the input's shape with every reduced axis
    replaced by 1 (e.g. input [2, 3, 5, 7] reduced over axes [1, 2] gives
    [2, 1, 1, 7]).

    Args:
        op: the reduction op; `op.inputs[0]` is the reduced tensor and
            `op.inputs[1]` holds the reduction axes.

    Returns:
        A tuple `(inp, new_output_shape, input_shape)`.
    """
    inp = op.inputs[0]  # Example:
    input_shape = array_ops.shape(inp)  # [2, 3, 5, 7]
    input_rank = array_ops.rank(inp)  # 4
    indices = op.inputs[1]  # [1, 2]
    # Map negative axes (e.g. -1) to their non-negative equivalents:
    # dynamic_stitch rejects negative indices. Axes already in [0, rank) are
    # unchanged by (axis + rank) % rank.
    indices = (math_ops.cast(indices, input_rank.dtype) + input_rank) % (
        input_rank)  # [1, 2]
    indices_shape = array_ops.shape(indices)  # [2]
    # Later stitch inputs win, so the reduced axes are overwritten with 1.
    new_output_shape = data_flow_ops.dynamic_stitch(
        [math_ops.range(input_rank),  # [0, 1, 2, 3]
         indices],  # [1, 2]
        [input_shape,  # [2, 3, 5, 7]
         array_ops.fill(indices_shape, 1)])  # [1, 1]
    # new_output_shape == [2, 1, 1, 7] in the example.
    return inp, new_output_shape, input_shape
开发者ID:adeelzaman,项目名称:tensorflow,代码行数:12,代码来源:math_grad.py


示例16: _DynamicPartitionGrads

def _DynamicPartitionGrads(op, *grads):
  """Gradients for DynamicPartition.

  There is one incoming gradient per output partition. The flat position of
  every element is partitioned with the same partition tensor (`op.inputs[1]`)
  and those positions are used to stitch the per-partition gradients back
  into the layout of `data`. The partition tensor itself gets no gradient.
  """
  data, partitions = op.inputs[0], op.inputs[1]
  num_partitions = op.get_attr("num_partitions")

  partitions_shape = array_ops.shape(partitions)
  # Flat positions 0..N-1 laid out in the shape of the partition tensor.
  positions = array_ops.reshape(
      math_ops.range(math_ops.reduce_prod(partitions_shape)), partitions_shape)
  partitioned_positions = data_flow_ops.dynamic_partition(
      positions, partitions, num_partitions)
  data_grad = data_flow_ops.dynamic_stitch(partitioned_positions, grads)
  return [array_ops.reshape(data_grad, array_ops.shape(data)), None]
开发者ID:13331151,项目名称:tensorflow,代码行数:14,代码来源:data_flow_grad.py


示例17: testSimpleTwoDimensional

 def testSimpleTwoDimensional(self):
   """Three index lists interleave 2-column rows into one 8x2 matrix."""
   with self.test_session():
     indices = [
         constant_op.constant(positions)
         for positions in ([0, 4, 7], [1, 6], [2, 3, 5])
     ]
     data = [
         constant_op.constant(rows)
         for rows in ([[0, 1], [40, 41], [70, 71]],
                      [[10, 11], [60, 61]],
                      [[20, 21], [30, 31], [50, 51]])
     ]
     stitched_t = data_flow_ops.dynamic_stitch(indices, data)
     # Row k of the result is [10 * k, 10 * k + 1] for k = 0..7.
     expected = [[10 * k, 10 * k + 1] for k in range(8)]
     self.assertAllEqual(expected, stitched_t.eval())
     # The row count depends on the runtime maximum index, so statically only
     # "a matrix with 2 columns and an unknown number of rows" is inferred.
     self.assertEqual([None, 2], stitched_t.get_shape().as_list())
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:19,代码来源:dynamic_stitch_op_test.py


示例18: _AssertDynamicStitchResultIs

  def _AssertDynamicStitchResultIs(self, indices, data, expected):
    """Feeds `indices`/`data` through placeholders and checks the stitch.

    Args:
      indices: list of index arrays; each needs a `.dtype` (presumably numpy
        arrays -- confirm against callers).
      data: list of data arrays, one per entry of `indices`.
      expected: the stitched result to compare against (rtol=1e-3).
    """
    with self.test_session() as session:
      index_placeholders = [
          array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in indices
      ]
      data_placeholders = [
          array_ops.placeholder(dtypes.as_dtype(arg.dtype)) for arg in data
      ]
      with self.test_scope():
        output = data_flow_ops.dynamic_stitch(index_placeholders,
                                              data_placeholders)

      # Every placeholder is a distinct key, so the two updates never clash.
      feed_dict = dict(zip(index_placeholders, indices))
      feed_dict.update(zip(data_placeholders, data))
      result = session.run(output, feed_dict=feed_dict)
      self.assertAllClose(expected, result, rtol=1e-3)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:19,代码来源:dynamic_stitch_test.py


示例19: _sample_n

  def _sample_n(self, n, seed=None):
    with ops.control_dependencies(self._assertions):
      n = ops.convert_to_tensor(n, name="n")
      static_n = tensor_util.constant_value(n)
      n = int(static_n) if static_n is not None else n
      cat_samples = self.cat.sample(n, seed=seed)

      static_samples_shape = cat_samples.get_shape()
      if static_samples_shape.is_fully_defined():
        samples_shape = static_samples_shape.as_list()
        samples_size = static_samples_shape.num_elements()
      else:
        samples_shape = array_ops.shape(cat_samples)
        samples_size = array_ops.size(cat_samples)
      static_batch_shape = self.get_batch_shape()
      if static_batch_shape.is_fully_defined():
        batch_shape = static_batch_shape.as_list()
        batch_size = static_batch_shape.num_elements()
      else:
        batch_shape = self.batch_shape()
        batch_size = array_ops.reduce_prod(batch_shape)
      static_event_shape = self.get_event_shape()
      if static_event_shape.is_fully_defined():
        event_shape = np.array(static_event_shape.as_list(), dtype=np.int32)
      else:
        event_shape = self.event_shape()

      # Get indices into the raw cat sampling tensor.  We will
      # need these to stitch sample values back out after sampling
      # within the component partitions.
      samples_raw_indices = array_ops.reshape(
          math_ops.range(0, samples_size), samples_shape)

      # Partition the raw indices so that we can use
      # dynamic_stitch later to reconstruct the samples from the
      # known partitions.
      partitioned_samples_indices = data_flow_ops.dynamic_partition(
          data=samples_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)

      # Copy the batch indices n times, as we will need to know
      # these to pull out the appropriate rows within the
      # component partitions.
      batch_raw_indices = array_ops.reshape(
          array_ops.tile(math_ops.range(0, batch_size), [n]), samples_shape)

      # Explanation of the dynamic partitioning below:
      #   batch indices are i.e., [0, 1, 0, 1, 0, 1]
      # Suppose partitions are:
      #     [1 1 0 0 1 1]
      # After partitioning, batch indices are cut as:
      #     [batch_indices[x] for x in 2, 3]
      #     [batch_indices[x] for x in 0, 1, 4, 5]
      # i.e.
      #     [1 1] and [0 0 0 0]
      # Now we sample n=2 from part 0 and n=4 from part 1.
      # For part 0 we want samples from batch entries 1, 1 (samples 0, 1),
      # and for part 1 we want samples from batch entries 0, 0, 0, 0
      #   (samples 0, 1, 2, 3).
      partitioned_batch_indices = data_flow_ops.dynamic_partition(
          data=batch_raw_indices,
          partitions=cat_samples,
          num_partitions=self.num_components)
      samples_class = [None for _ in range(self.num_components)]

      for c in range(self.num_components):
        n_class = array_ops.size(partitioned_samples_indices[c])
        seed = distribution_util.gen_new_seed(seed, "mixture")
        samples_class_c = self.components[c].sample(n_class, seed=seed)

        # Pull out the correct batch entries from each index.
        # To do this, we may have to flatten the batch shape.

        # For sample s, batch element b of component c, we get the
        # partitioned batch indices from
        # partitioned_batch_indices[c]; and shift each element by
        # the sample index.  The final lookup can be thought of as
        # a matrix gather along locations (s, b) in
        # samples_class_c where the n_class rows correspond to
        # samples within this component and the batch_size columns
        # correspond to batch elements within the component.
        #
        # Thus the lookup index is
        #   lookup[c, i] = batch_size * s[i] + b[c, i]
        # for i = 0 ... n_class[c] - 1.
        lookup_partitioned_batch_indices = (
            batch_size * math_ops.range(n_class) +
            partitioned_batch_indices[c])
        samples_class_c = array_ops.reshape(
            samples_class_c,
            array_ops.concat(([n_class * batch_size], event_shape), 0))
        samples_class_c = array_ops.gather(
            samples_class_c, lookup_partitioned_batch_indices,
            name="samples_class_c_gather")
        samples_class[c] = samples_class_c

      # Stitch back together the samples across the components.
      lhs_flat_ret = data_flow_ops.dynamic_stitch(
          indices=partitioned_samples_indices, data=samples_class)
#.........这里部分代码省略.........
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:101,代码来源:mixture.py


示例20: embedding_lookup


#.........这里部分代码省略.........
      else:
        ndims = array_ops.size(array_ops.shape(x))
      return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
    return x
  with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
    np = len(params)  # Number of partitions
    params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
    if np == 1:
      with ops.colocate_with(params[0]):
        # TODO(apassos): implement the sharded version as well.
        if isinstance(params[0], resource_variable_ops.ResourceVariable):
          ret = params[0].sparse_read(ids, name=name)
        else:
          ret = array_ops.gather(params[0], ids, name=name,
                                 validate_indices=validate_indices)
      return maybe_normalize(ret)
    else:
      ids = ops.convert_to_tensor(ids, name="ids")
      flat_ids = array_ops.reshape(ids, [-1])
      original_indices = math_ops.range(array_ops.size(flat_ids))

      # Create p_assignments and set new_ids depending on the strategy.
      if partition_strategy == "mod":
        p_assignments = flat_ids % np
        new_ids = flat_ids // np
      elif partition_strategy == "div":
        # Compute num_total_ids as the sum of dim-0 of params, then assign to
        # partitions based on a constant number of ids per partition. Optimize
        # if we already know the full shape statically.
        dim_0_size = params[0].get_shape()[0]
        for p in xrange(1, np):
          dim_0_size += params[p].get_shape()[0]
        if dim_0_size.value:
          num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
        else:
          dim_0_sizes = []
          for p in xrange(np):
            if params[p].get_shape()[0].value is not None:
              dim_0_sizes.append(params[p].get_shape()[0].value)
            else:
              with ops.colocate_with(params[p]):
                dim_0_sizes.append(array_ops.shape(params[p])[0])
          num_total_ids = math_ops.reduce_sum(
              math_ops.cast(array_ops.pack(dim_0_sizes), flat_ids.dtype))
        ids_per_partition = num_total_ids // np
        extras = num_total_ids % np

        p_assignments = math_ops.maximum(
            flat_ids // (ids_per_partition + 1),
            (flat_ids - extras) // ids_per_partition)

        # Emulate a conditional using a boolean indicator tensor
        is_in_first_extras_partitions = math_ops.cast(
            p_assignments < extras, flat_ids.dtype)
        new_ids = (
            is_in_first_extras_partitions * (
                flat_ids % (ids_per_partition + 1)) +
            (1 - is_in_first_extras_partitions) * (
                (flat_ids - extras) % ids_per_partition))
      else:
        raise ValueError("Unrecognized partition strategy: " +
                         partition_strategy)

      # Cast partition assignments to int32 for use in dynamic_partition.
      # There really should not be more than 2^32 partitions.
      p_assignments = math_ops.cast(p_assignments, dtypes.int32)
      # Partition list of ids based on assignments into np separate lists
      gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
      # Similarly, partition the original indices.
      pindices = data_flow_ops.dynamic_partition(original_indices,
                                                 p_assignments, np)
      # Do np separate lookups, finding embeddings for plist[p] in params[p]
      partitioned_result = []
      for p in xrange(np):
        with ops.colocate_with(params[p]):
          partitioned_result.append(array_ops.gather(
              params[p], gather_ids[p],
              validate_indices=validate_indices))
      # Stitch these back together
      ret = data_flow_ops.dynamic_stitch(pindices, partitioned_result,
                                         name=name)
      # Reshape to reverse the flattening of ids.
      element_shape = params[0].get_shape()[1:]
      for p in params[1:]:
        element_shape = element_shape.merge_with(p.get_shape()[1:])
      if element_shape.is_fully_defined():
        ret = array_ops.reshape(ret, array_ops.concat(0, [
            array_ops.shape(ids), element_shape]))
      else:
        # It's important that we compute params[0].shape on the right device
        # to avoid data motion.
        with ops.colocate_with(params[0]):
          params_shape = array_ops.shape(params[0])
        ret = array_ops.reshape(ret, array_ops.concat(0, [
            array_ops.shape(ids), array_ops.slice(params_shape, [1], [-1])]))
      # output shape = ids.shape + params[*].shape[1:]
      # Normally the reshape is sufficient, but setting shape explicitly
      # teaches shape inference that params[1:].get_shape() matters.
      ret.set_shape(ids.get_shape().concatenate(element_shape))
      return maybe_normalize(ret)
开发者ID:HKUST-SING,项目名称:tensorflow,代码行数:101,代码来源:embedding_ops.py



注:本文中的tensorflow.python.ops.data_flow_ops.dynamic_stitch函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python data_flow_ops.initialize_all_tables函数代码示例发布时间:2022-05-27
下一篇:
Python data_flow_ops.dynamic_partition函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap