• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python list_ops.tensor_list_stack函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.ops.list_ops.tensor_list_stack函数的典型用法代码示例。如果您正苦于以下问题：Python tensor_list_stack函数的具体用法？Python tensor_list_stack怎么用？Python tensor_list_stack使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了tensor_list_stack函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: testStackEmptyList

  def testStackEmptyList(self, max_num_elements):
    """Checks which empty TensorLists can be stacked.

    Stacking an empty list succeeds only when `element_shape` is fully
    defined; partially defined or unknown element shapes must raise
    InvalidArgumentError when evaluated.

    Args:
      max_num_elements: forwarded to `list_ops.empty_tensor_list`
        (presumably supplied by a parameterization decorator outside this
        view).
    """
    # Should be able to stack empty lists with fully defined element_shape.
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=[1, 2],
        max_num_elements=max_num_elements)
    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    # Zero stacked elements, each of shape [1, 2].
    self.assertAllEqual(self.evaluate(t).shape, (0, 1, 2))

    # Should not be able to stack empty lists with partially defined
    # element_shape. `assertRaisesRegex` replaces the deprecated
    # `assertRaisesRegexp` alias, which Python 3.12 removed.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=[None, 2],
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)

    # Should not be able to stack empty lists with undefined element_shape.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "non-fully-defined"):
      l = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=None,
          max_num_elements=max_num_elements)
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)
开发者ID:aeverall,项目名称:tensorflow,代码行数:29,代码来源:list_ops_test.py


示例2: testConcat

  def testConcat(self):
    """Element-wise concatenation of two batches of TensorLists.

    Builds two batches of scalar float32 lists, concatenates them pairwise
    in every order and checks each resulting list's stacked contents. Also
    checks that shape and dtype mismatches are rejected.
    """
    c = constant_op.constant([1.0, 2.0], dtype=dtypes.float32)
    l0 = list_ops.tensor_list_from_tensor(c, element_shape=scalar_shape())
    l1 = list_ops.tensor_list_from_tensor([-1.0], element_shape=scalar_shape())
    # Two batches of two lists each: [l0, l1] and [l1, l0].
    l_batch_0 = array_ops.stack([l0, l1])
    l_batch_1 = array_ops.stack([l1, l0])

    l_concat_01 = list_ops.tensor_list_concat_lists(
        l_batch_0, l_batch_1, element_dtype=dtypes.float32)
    l_concat_10 = list_ops.tensor_list_concat_lists(
        l_batch_1, l_batch_0, element_dtype=dtypes.float32)
    l_concat_00 = list_ops.tensor_list_concat_lists(
        l_batch_0, l_batch_0, element_dtype=dtypes.float32)
    l_concat_11 = list_ops.tensor_list_concat_lists(
        l_batch_1, l_batch_1, element_dtype=dtypes.float32)

    expected_00 = [[1.0, 2.0, 1.0, 2.0], [-1.0, -1.0]]
    expected_01 = [[1.0, 2.0, -1.0], [-1.0, 1.0, 2.0]]
    expected_10 = [[-1.0, 1.0, 2.0], [1.0, 2.0, -1.0]]
    expected_11 = [[-1.0, -1.0], [1.0, 2.0, 1.0, 2.0]]

    for i, (concat, expected) in enumerate(zip(
        [l_concat_00, l_concat_01, l_concat_10, l_concat_11],
        [expected_00, expected_01, expected_10, expected_11])):
      split = array_ops.unstack(concat)
      split_stacked_ret = self.evaluate(
          (list_ops.tensor_list_stack(split[0], dtypes.float32),
           list_ops.tensor_list_stack(split[1], dtypes.float32)))
      # Identify the failing case through the assertion message instead of
      # an unconditional debug print on every iteration.
      msg = "concat case %d" % i
      self.assertAllClose(expected[0], split_stacked_ret[0], msg=msg)
      self.assertAllClose(expected[1], split_stacked_ret[1], msg=msg)

    # Concatenating mismatched shapes fails.
    with self.assertRaises((errors.InvalidArgumentError, ValueError)):
      self.evaluate(
          list_ops.tensor_list_concat_lists(
              l_batch_0,
              list_ops.empty_tensor_list(scalar_shape(), dtypes.float32),
              element_dtype=dtypes.float32))

    # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias,
    # which Python 3.12 removed.
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "element shapes are not identical at index 0"):
      l_batch_of_vec_tls = array_ops.stack(
          [list_ops.tensor_list_from_tensor([[1.0]], element_shape=[1])] * 2)
      self.evaluate(
          list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_vec_tls,
                                            element_dtype=dtypes.float32))

    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                r"input_b\[0\].dtype != element_dtype."):
      l_batch_of_int_tls = array_ops.stack(
          [list_ops.tensor_list_from_tensor([1], element_shape=scalar_shape())]
          * 2)
      self.evaluate(
          list_ops.tensor_list_concat_lists(l_batch_0, l_batch_of_int_tls,
                                            element_dtype=dtypes.float32))
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:58,代码来源:list_ops_test.py


示例3: testAddTensorListsFailsIfLeadingDimsMismatch

 def testAddTensorListsFailsIfLeadingDimsMismatch(self):
   """AddN over TensorLists reserved with 2 and 3 elements must fail."""
   with self.cached_session(), self.test_scope():
     l1 = list_ops.tensor_list_reserve(
         element_shape=[], element_dtype=dtypes.float32, num_elements=2)
     l2 = list_ops.tensor_list_reserve(
         element_shape=[], element_dtype=dtypes.float32, num_elements=3)
     l = math_ops.add_n([l1, l2])
     # add_n is built outside the assertion context, so the test expects the
     # error when the result is evaluated. `assertRaisesRegex` replaces the
     # deprecated `assertRaisesRegexp` alias, which Python 3.12 removed.
     with self.assertRaisesRegex(
         errors.InvalidArgumentError,
         "TensorList arguments to AddN must all have the same shape"):
       list_ops.tensor_list_stack(l, element_dtype=dtypes.float32).eval()
开发者ID:Albert-Z-Guo,项目名称:tensorflow,代码行数:11,代码来源:add_n_test.py


示例4: test_tensor_list_empty_list

  def test_tensor_list_empty_list(self):
    """A TensorList built from an empty list or empty tuple stacks to []."""
    for empty_elements in ([], ()):
      tensor_list = special_functions.tensor_list(
          empty_elements, element_dtype=dtypes.int32, element_shape=())
      stacked = list_ops.tensor_list_stack(
          tensor_list, element_dtype=dtypes.int32)
      with self.cached_session() as sess:
        self.assertAllEqual(sess.run(stacked), [])
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:14,代码来源:special_functions_test.py


示例5: testStackWithPartiallyDefinedElementShape

  def testStackWithPartiallyDefinedElementShape(self):
    """Stacks a list declared with element_shape [-1].

    Pushing two length-1 vectors stacks fine; pushing a length-2 vector
    afterwards must make stacking fail with "unequal shapes".
    """
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32, element_shape=[-1])
    l = list_ops.tensor_list_push_back(l, constant_op.constant([1.0]))
    l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0]))

    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [[1.0], [2.0]])

    # Should raise an error when the element tensors do not all have the same
    # shape. `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`.
    with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
      l = list_ops.tensor_list_push_back(l, constant_op.constant([2.0, 3.0]))
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)
开发者ID:becster,项目名称:tensorflow,代码行数:15,代码来源:list_ops_test.py


示例6: test_tf_tensor_list_new_empty

 def test_tf_tensor_list_new_empty(self):
   """A TensorList created by tf_tensor_list_new from [] stacks to []."""
   empty_list = data_structures.tf_tensor_list_new(
       [], element_dtype=dtypes.int32, element_shape=())
   stacked = list_ops.tensor_list_stack(empty_list, element_dtype=dtypes.int32)
   with self.cached_session() as sess:
     result = sess.run(stacked)
     self.assertAllEqual(result, [])
开发者ID:bunbutter,项目名称:tensorflow,代码行数:7,代码来源:data_structures_test.py


示例7: test_tensor_list_from_elements

  def test_tensor_list_from_elements(self):
    """A TensorList built from two vectors stacks back into a 2x2 result."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

    l = special_functions.tensor_list(elements)
    sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    # `cached_session` is the replacement for the deprecated `test_session`.
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
开发者ID:AnishShah,项目名称:tensorflow,代码行数:7,代码来源:special_functions_test.py


示例8: testStack

 def testStack(self):
   """Pushes two scalars onto an empty list and stacks them."""
   l = list_ops.empty_tensor_list(element_dtype=dtypes.float32,
                                  element_shape=scalar_shape())
   l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
   l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))
   t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
   # Evaluate explicitly so the comparison also works when `t` is a symbolic
   # graph-mode tensor rather than an eager value.
   self.assertAllEqual(self.evaluate(t), [1.0, 2.0])
开发者ID:andrewharp,项目名称:tensorflow,代码行数:7,代码来源:list_ops_test.py


示例9: test_list_pop

  def test_list_pop(self):
    """After conversion, `l.pop()` yields the last item and the shorter list."""

    def test_fn():
      l = [1, 2, 3]
      utils.set_element_type(l, dtypes.int32, ())
      s = l.pop()
      return s, l

    node = self.parse_and_analyze(
        test_fn,
        {
            'utils': utils,
            'dtypes': dtypes
        },
        include_type_analysis=True,
    )
    node = lists.transform(node, self.ctx)

    with self.compiled(node) as result:
      # Presumably the compiled code resolves these names from the module it
      # is loaded into, so they are injected here.
      result.utils = utils
      result.dtypes = dtypes
      # `cached_session` is the replacement for the deprecated `test_session`.
      with self.cached_session() as sess:
        ts, tl = result.test_fn()
        r = list_ops.tensor_list_stack(tl, dtypes.int32)
        self.assertAllEqual(sess.run(r), [1, 2])
        self.assertAllEqual(sess.run(ts), 3)
开发者ID:LiuCKind,项目名称:tensorflow,代码行数:26,代码来源:lists_test.py


示例10: testAddN

 def testAddN(self):
   """add_n over three scalar-element lists sums them element-wise."""
   addends = [
       list_ops.tensor_list_from_tensor(values, element_shape=[])
       for values in ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])
   ]
   summed = math_ops.add_n(tuple(addends))
   stacked = list_ops.tensor_list_stack(summed, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(stacked), [9., 12.])
开发者ID:aeverall,项目名称:tensorflow,代码行数:7,代码来源:list_ops_test.py


示例11: testPruning

  def testPruning(self):
    """Counts Enter nodes in the optimized graph as fetches are added.

    With only the loop counter registered in TRAIN_OP, the optimized graph is
    expected to contain one Enter node. Once the stacked TensorList is also
    registered, it is expected to contain two.
    """
    start = constant_op.constant(1)

    tensor_list = list_ops.empty_tensor_list(
        element_dtype=start.dtype, element_shape=start.shape)

    def cond(counter, unused_tl):
      return counter < 5

    def body(counter, tl):
      return counter + 1, list_ops.tensor_list_push_back(tl, counter)

    outputs = while_loop_v1(cond, body, [start, tensor_list])

    fetches = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    fetches.append(outputs[0])

    def optimized_graph():
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph=ops.get_default_graph())
      config = rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
      return tf_optimizer.OptimizeGraph(config, meta_graph_def)

    def count_enter_nodes(graph_def):
      return sum(1 for node in graph_def.node if node.op == "Enter")

    self.assertEqual(count_enter_nodes(optimized_graph()), 1)

    stacked = list_ops.tensor_list_stack(outputs[1], element_dtype=start.dtype)
    fetches.append(stacked)
    self.assertEqual(count_enter_nodes(optimized_graph()), 2)
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:32,代码来源:while_v2_test.py


示例12: testSetStackReservedUnknownElementShape

 def testSetStackReservedUnknownElementShape(self):
   """Stacks a reserved list with unknown element shape after one set_item."""
   with self.cached_session(), self.test_scope():
     reserved = list_ops.tensor_list_reserve(
         element_dtype=dtypes.float32, element_shape=None, num_elements=2)
     updated = list_ops.tensor_list_set_item(reserved, 0, [3.0, 4.0])
     stacked = list_ops.tensor_list_stack(
         updated, element_dtype=dtypes.float32)
     # Slot 1 was never written; the test expects zeros shaped like slot 0.
     self.assertAllEqual(stacked, [[3.0, 4.0], [0., 0.]])
开发者ID:jackd,项目名称:tensorflow,代码行数:7,代码来源:tensor_list_ops_test.py


示例13: testStackWithUnknownElementShape

  def testStackWithUnknownElementShape(self, max_num_elements):
    """Stacks a list created with element_shape=None.

    Two scalars stack fine; pushing a length-2 vector afterwards must make
    stacking fail with "unequal shapes".

    Args:
      max_num_elements: forwarded to `list_ops.empty_tensor_list`
        (presumably supplied by a parameterization decorator).
    """
    l = list_ops.empty_tensor_list(
        element_dtype=dtypes.float32,
        element_shape=None,
        max_num_elements=max_num_elements)
    l = list_ops.tensor_list_push_back(l, constant_op.constant(1.0))
    l = list_ops.tensor_list_push_back(l, constant_op.constant(2.0))

    t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
    self.assertAllEqual(self.evaluate(t), [1.0, 2.0])

    # Should raise an error when the element tensors do not all have the same
    # shape. `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp`.
    with self.assertRaisesRegex(errors.InvalidArgumentError, "unequal shapes"):
      l = list_ops.tensor_list_push_back(l, constant_op.constant([3.0, 4.0]))
      t = list_ops.tensor_list_stack(l, element_dtype=dtypes.float32)
      self.evaluate(t)
开发者ID:aeverall,项目名称:tensorflow,代码行数:17,代码来源:list_ops_test.py


示例14: testGetSetItem

 def testGetSetItem(self):
   """get_item reads element 0; set_item overwrites it before stacking."""
   source = constant_op.constant([1.0, 2.0])
   tl = list_ops.tensor_list_from_tensor(source, element_shape=[])
   first = list_ops.tensor_list_get_item(tl, 0, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(first), 1.0)
   tl = list_ops.tensor_list_set_item(tl, 0, 3.0)
   stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(stacked), [3.0, 2.0])
开发者ID:aeverall,项目名称:tensorflow,代码行数:8,代码来源:list_ops_test.py


示例15: test_append_tensor_list

  def test_append_tensor_list(self):
    """Appending one vector to a new list stacks to a single-row result."""
    l = data_structures.new_list()
    x = constant_op.constant([1, 2, 3])
    l = data_structures.list_append(l, x)

    t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
    # `cached_session` is the replacement for the deprecated `test_session`.
    with self.cached_session() as sess:
      self.assertAllEqual(sess.run(t), [[1, 2, 3]])
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:8,代码来源:data_structures_test.py


示例16: stack

 def stack(self, name=None):
   """Stacks every element of the underlying TensorList into one Tensor.

   See TensorArray.

   Args:
     name: optional name scope; "TensorArrayV2Stack" is used when omitted.

   Returns:
     The stacked Tensor. When `self._element_shape` is non-empty and its first
     entry has known `dims`, the static shape is set to [None] + those dims.
   """
   with ops.name_scope(name, "TensorArrayV2Stack", [self._flow]):
     stacked = list_ops.tensor_list_stack(
         input_handle=self._flow, element_dtype=self._dtype)
     shape_info = self._element_shape
     if shape_info:
       element_dims = shape_info[0].dims
       if element_dims is not None:
         # The leading (element-count) dimension is left statically unknown.
         stacked.set_shape([None] + element_dims)
     return stacked
开发者ID:terrytangyuan,项目名称:tensorflow,代码行数:8,代码来源:tensor_array_ops.py


示例17: testGraphStack

 def testGraphStack(self):
   """Stacks a one-element list in graph mode.

   The element shape is passed as an int32 Tensor rather than a Python list.
   """
   # `cached_session` is the replacement for the deprecated `test_session`.
   with context.graph_mode(), self.cached_session():
     tl = list_ops.empty_tensor_list(
         element_shape=constant_op.constant([1], dtype=dtypes.int32),
         element_dtype=dtypes.int32)
     tl = list_ops.tensor_list_push_back(tl, [1])
     self.assertAllEqual(
         list_ops.tensor_list_stack(tl, element_dtype=dtypes.int32).eval(),
         [[1]])
开发者ID:DILASSS,项目名称:tensorflow,代码行数:9,代码来源:list_ops_test.py


示例18: testGetSet

 def testGetSet(self):
   """Reads element 0, overwrites it, and checks the stacked result."""
   with self.cached_session(), self.test_scope():
     values = constant_op.constant([1.0, 2.0])
     tl = list_ops.tensor_list_from_tensor(values, element_shape=[])
     head = list_ops.tensor_list_get_item(tl, 0, element_dtype=dtypes.float32)
     self.assertAllEqual(head, 1.0)
     tl = list_ops.tensor_list_set_item(tl, 0, 3.0)
     stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32)
     self.assertAllEqual(stacked, [3.0, 2.0])
开发者ID:jackd,项目名称:tensorflow,代码行数:9,代码来源:tensor_list_ops_test.py


示例19: _testStack

 def _testStack(self, max_num_elements):
   """Pushes 1.0 then 2.0 onto an empty scalar list and stacks them.

   Args:
     max_num_elements: forwarded to `list_ops.empty_tensor_list`.
   """
   tl = list_ops.empty_tensor_list(
       element_dtype=dtypes.float32,
       element_shape=scalar_shape(),
       max_num_elements=max_num_elements)
   for value in (1.0, 2.0):
     tl = list_ops.tensor_list_push_back(tl, constant_op.constant(value))
   stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32)
   self.assertAllEqual(self.evaluate(stacked), [1.0, 2.0])
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:9,代码来源:list_ops_test.py


示例20: testStackFromTensorGradients

 def testStackFromTensorGradients(self):
   """Gradients flow through tensor_list_from_tensor + tensor_list_stack.

   The watched tensor is stacked and then doubled, so the test expects a
   gradient of 2.0 for every element.
   """
   with backprop.GradientTape() as tape:
     source = constant_op.constant([1.0, 2.0])
     tape.watch(source)
     tl = list_ops.tensor_list_from_tensor(
         source, element_shape=scalar_shape())
     stacked = list_ops.tensor_list_stack(tl, element_dtype=dtypes.float32)
     doubled = stacked * 2.0
   grads = tape.gradient(doubled, [source])
   self.assertAllEqual(grads[0], [2.0, 2.0])
开发者ID:andrewharp,项目名称:tensorflow,代码行数:9,代码来源:list_ops_test.py



注：本文中的tensorflow.python.ops.list_ops.tensor_list_stack函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python logging_ops.get_summary_op函数代码示例发布时间:2022-05-27
下一篇:
Python list_ops.tensor_list_set_item函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap