• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python control_flow_ops.merge函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.ops.control_flow_ops.merge函数的典型用法代码示例。如果您正苦于以下问题:Python merge函数的具体用法?Python merge怎么用?Python merge使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了merge函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: testMergeShapes

  def testMergeShapes(self):
    """Checks the static shapes that merge infers for its two outputs.

    In the cases checked here the merged value only gets a static shape when
    every input has the same fully known shape; the index output is always a
    scalar (shape ``[]``).
    """
    # Case 1: no input shape is known.
    inputs = [tf.placeholder(tf.float32) for _ in range(3)]
    merged, idx = control_flow_ops.merge(inputs)
    self.assertIs(None, merged.get_shape().ndims)
    self.assertEqual([], idx.get_shape())

    # Case 2: every shape is fully known, but they disagree.
    inputs = [tf.placeholder(tf.float32, shape=s) for s in ([1, 2], [2, 1])]
    merged, idx = control_flow_ops.merge(inputs)
    self.assertIs(None, merged.get_shape().ndims)
    self.assertEqual([], idx.get_shape())

    # Case 3: every shape is fully known and identical.
    inputs = [tf.placeholder(tf.float32, shape=[1, 2]) for _ in range(2)]
    merged, idx = control_flow_ops.merge(inputs)
    self.assertEqual([1, 2], merged.get_shape())
    self.assertEqual([], idx.get_shape())

    # Case 4: compatible but not guaranteed equal ([1, 2] vs [?, 2]).
    fixed = tf.placeholder(tf.float32, shape=[1, 2])
    partial = tf.placeholder(tf.float32)
    partial.set_shape([None, 2])
    merged, idx = control_flow_ops.merge([fixed, partial])
    self.assertIs(None, merged.get_shape().ndims)
    self.assertEqual([], idx.get_shape())
开发者ID:hypatiad,项目名称:tensorflow,代码行数:30,代码来源:control_flow_ops_py_test.py


示例2: testLoop_1

    def testLoop_1(self):
        """Builds a while-loop frame by hand from raw control-flow primitives.

        Three values (0, 1 and n=10) enter frame "foo_1". Each one gets its
        own Merge -> Switch -> NextIteration cycle, and the Exit taken from
        the ``n`` switch is evaluated.

        NOTE(review): the loop predicate is ``less(merge_n, merge_n)``. That
        is always False, so control leaves through the false output of
        ``switch_n`` on the first check and the result is simply ``n`` (10).
        The back edges are wired but never taken. Presumably this is
        intentional (it exercises only the exit path); confirm before changing
        the predicate, because the body has no increment and a True predicate
        would never terminate.
        """
        with self.test_session():
            zero = tf.convert_to_tensor(0)
            one = tf.convert_to_tensor(1)
            n = tf.constant(10)

            # All three values enter the same frame "foo_1". The third
            # positional flag is presumably is_constant -- TODO confirm
            # against enter().
            enter_zero = control_flow_ops.enter(zero, "foo_1", False)
            enter_one = control_flow_ops.enter(one, "foo_1", False)
            enter_n = control_flow_ops.enter(n, "foo_1", False)
            # Each Enter is passed twice only as a placeholder. Merge input 1
            # is replaced below by the matching NextIteration (the back edge).
            merge_zero = control_flow_ops.merge([enter_zero, enter_zero], name="merge_zero")[0]
            merge_one = control_flow_ops.merge([enter_one, enter_one], name="merge_one")[0]
            merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
            # x < x is always False: see NOTE(review) in the docstring.
            less_op = tf.less(merge_n, merge_n)
            cond_op = control_flow_ops.loop_cond(less_op)
            switch_zero = control_flow_ops.switch(merge_zero, cond_op)
            switch_one = control_flow_ops.switch(merge_one, cond_op)
            switch_n = control_flow_ops.switch(merge_n, cond_op)
            # Output [1] of each switch is the "continue" (true) side and feeds
            # the next iteration unchanged.
            next_zero = control_flow_ops.next_iteration(switch_zero[1])
            next_one = control_flow_ops.next_iteration(switch_one[1])
            next_n = control_flow_ops.next_iteration(switch_n[1])
            # Close the loops: rewire merge input 1 to the back edge.
            merge_zero.op._update_input(1, next_zero)
            merge_one.op._update_input(1, next_one)
            merge_n.op._update_input(1, next_n)
            # Output [0] (false side) leaves the frame.
            exit_n = control_flow_ops.exit(switch_n[0])

            result = exit_n.eval()
        self.assertAllEqual(10, result)
开发者ID:peace195,项目名称:tensorflow,代码行数:27,代码来源:control_flow_ops_py_test.py


示例3: testLoop_2

  def testLoop_2(self):
    """Counts i from 0 up to n=10 with a hand-wired while loop.

    The NextIteration op is built inside a "/gpu:0" device scope. The loop
    exits through the switch's false output, so the expected result is 10.
    """
    with self.test_session():
      zero = tf.constant(0)
      one = tf.constant(1)
      n = tf.constant(10)

      # The loop variable enters with the flag False; the invariants (one, n)
      # enter with True. The flag is presumably is_constant -- confirm
      # against enter().
      enter_i = control_flow_ops.enter(zero, "foo", False)
      enter_one = control_flow_ops.enter(one, "foo", True)
      enter_n = control_flow_ops.enter(n, "foo", True)

      # enter_i is passed twice only as a placeholder; input 1 is replaced by
      # the back edge (next_i) below.
      merge_i = control_flow_ops.merge([enter_i, enter_i])[0]

      # Loop condition: i < n.
      less_op = tf.less(merge_i, enter_n)
      cond_op = control_flow_ops.loop_cond(less_op)
      switch_i = control_flow_ops.switch(merge_i, cond_op)

      # Loop body: i + 1, taken from the switch's true output.
      add_i = tf.add(switch_i[1], enter_one)

      # NOTE(review): pins NextIteration to GPU 0, presumably to exercise a
      # back edge that crosses devices. Confirm that the test session allows
      # soft placement on hosts without a GPU.
      with tf.device("/gpu:0"):
        next_i = control_flow_ops.next_iteration(add_i)
      merge_i.op._update_input(1, next_i)

      # Leave the frame through the false output once i < n no longer holds.
      exit_i = control_flow_ops.exit(switch_i[0])
      result = exit_i.eval()
    self.assertAllEqual(10, result)
开发者ID:hypatiad,项目名称:tensorflow,代码行数:25,代码来源:control_flow_ops_py_test.py


示例4: apply_with_random_selector

def apply_with_random_selector(image, func, num_cases):
    """Apply ``func(image, case)`` for one case chosen at run time.

    A scalar int32 selector is sampled in ``[0, num_cases)``. One branch is
    built per case. ``switch`` forwards ``image`` only into the branch whose
    case equals the selector, and ``merge`` returns the output of the branch
    that actually ran.

    Args:
      image: input Tensor.
      func: callable invoked as ``func(tensor, case)`` with a Python int case.
      num_cases: Python int, number of cases to choose from.

    Returns:
      The output of ``func`` for the sampled case.
    """
    selector = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
    branches = []
    for case in range(num_cases):
        # switch(...)[1] carries the real image only when the predicate holds.
        routed = control_flow_ops.switch(image, tf.equal(case, selector))[1]
        branches.append(func(routed, case))
    merged_value, _ = control_flow_ops.merge(branches)
    return merged_value
开发者ID:beacandler,项目名称:tf-slim-demo,代码行数:7,代码来源:inception_preprocessing.py


示例5: _process_switch

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_op_using_resource_tensor, merge_for_resource):
    """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor gets created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    If the switch's first output already has an entry in merge_for_resource,
    the switch has been processed before. The method then returns without
    creating another merge.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_op_using_resource_tensor: map from resource tensor to last op using
        it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.

    Returns:
      None. All effects are in-place: the three bookkeeping arguments are
      updated, a merge op is added to the graph, and control inputs are added
      to switch_op and, when present, to the merge already registered for the
      switch's input.
    """
    inp = switch_op.inputs[0]
    # Nested conds: if the resource itself arrived through another Switch,
    # give that switch its merge first.
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run,
                           last_op_using_resource_tensor, merge_for_resource)
    # Already processed (its outputs were registered below on an earlier call).
    if switch_op.outputs[0] in merge_for_resource:
      return
    # Join both switch outputs back together, then move the merge out of the
    # branch into the switch's enclosing control flow context.
    new_merge = control_flow_ops.merge(switch_op.outputs,
                                       name="artificial_merge")
    new_merge[0].op._control_flow_context = (  # pylint: disable=protected-access
        switch_op._control_flow_context.outer_context)  # pylint: disable=protected-access
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if inp in last_op_using_resource_tensor:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_op_using_resource_tensor[inp])  # pylint: disable=protected-access
    # Ensure the next op outside the cond happens after the merge.
    last_op_using_resource_tensor[inp] = new_merge[0].op
    # If a merge is already registered for inp (e.g. inp is an output of a
    # previously processed switch), that merge must wait for this new one.
    if inp in merge_for_resource:
      merge_for_resource[inp]._add_control_input(new_merge[0].op)  # pylint: disable=protected-access
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[o] = new_merge[0].op
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:58,代码来源:function.py


示例6: _testSwitchMerge_1

  def _testSwitchMerge_1(self, use_gpu):
    """A switch on a True predicate followed by merge passes the data through."""
    with self.test_session(use_gpu=use_gpu):
      data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = tf.convert_to_tensor(True, name="ports")
      merged = control_flow_ops.merge(control_flow_ops.switch(data, ports))[0]
      result = merged.eval()
    expected = np.arange(1, 7)
    self.assertAllEqual(expected, result)
开发者ID:hypatiad,项目名称:tensorflow,代码行数:9,代码来源:control_flow_ops_py_test.py


示例7: testSwitchMergeIdentity_1

  def testSwitchMergeIdentity_1(self):
    """Switch -> merge -> identity on a True predicate returns the data."""
    with self.test_session():
      data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = tf.convert_to_tensor(True, name="ports")
      branches = control_flow_ops.switch(data, ports)
      merged = control_flow_ops.merge(branches)[0]
      result = tf.identity(merged).eval()
    self.assertAllEqual(np.arange(1, 7), result)
开发者ID:hypatiad,项目名称:tensorflow,代码行数:10,代码来源:control_flow_ops_py_test.py


示例8: testSwitchMergeLess_1

  def testSwitchMergeLess_1(self):
    """Switch on a computed predicate (0 < 1, i.e. True), then merge."""
    with self.test_session():
      data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
      pred = tf.less(tf.convert_to_tensor(0), tf.convert_to_tensor(1))
      merged = control_flow_ops.merge(control_flow_ops.switch(data, pred))[0]
      result = merged.eval()
    self.assertAllEqual(np.arange(1, 7), result)
开发者ID:hypatiad,项目名称:tensorflow,代码行数:11,代码来源:control_flow_ops_py_test.py


示例9: testSwitchMergeAddIdentity_1

  def testSwitchMergeAddIdentity_1(self):
    """Only the branch selected by the predicate reaches the merge.

    With a True predicate, switch output [1] carries the data. The merge
    therefore yields the identity branch rather than the ``+ 1`` branch built
    on output [0].
    """
    with self.test_session():
      data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = tf.convert_to_tensor(True, name="ports")
      switch_out = control_flow_ops.switch(data, ports)
      false_out = switch_out[0]
      true_out = switch_out[1]
      one = tf.constant(1)
      plus_one = tf.add(false_out, one)
      passthrough = tf.identity(true_out)
      result = control_flow_ops.merge([plus_one, passthrough])[0].eval()
    self.assertAllEqual(np.arange(1, 7), result)
开发者ID:hypatiad,项目名称:tensorflow,代码行数:12,代码来源:control_flow_ops_py_test.py


示例10: testSwitchMergeAddMul_1

  def testSwitchMergeAddMul_1(self):
    """With a True predicate the merge yields the ``* 5`` (true-side) branch."""
    with self.test_session():
      data = tf.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = tf.convert_to_tensor(True, name="ports")
      switch_out = control_flow_ops.switch(data, ports)
      one = tf.constant(1)
      plus_one = tf.add(switch_out[0], one)
      five = tf.constant(5)
      times_five = tf.mul(switch_out[1], five)
      merged = control_flow_ops.merge([plus_one, times_five])[0]
      result = merged.eval()
    expected = np.arange(1, 7) * 5
    self.assertAllEqual(expected, result)
开发者ID:hypatiad,项目名称:tensorflow,代码行数:13,代码来源:control_flow_ops_py_test.py


示例11: testSwitchMergeIndexedSlices

  def testSwitchMergeIndexedSlices(self):
    """IndexedSlices survive a switch/merge round trip on a True predicate."""
    with self.test_session():
      slices = tf.IndexedSlices(tf.constant([1, 2, 3, 4, 5, 6]),
                                tf.constant([0, 2, 4, 6, 8, 10]))
      pred = tf.convert_to_tensor(True)
      merged = control_flow_ops.merge(control_flow_ops.switch(slices, pred))[0]

      val = merged.values.eval()
      ind = merged.indices.eval()
    self.assertAllEqual(np.arange(1, 7), val)
    self.assertAllEqual(np.arange(0, 12, 2), ind)
开发者ID:hypatiad,项目名称:tensorflow,代码行数:13,代码来源:control_flow_ops_py_test.py


示例12: testLoop_false

    def testLoop_false(self):
        """A switch whose predicate is a constant False routes straight to Exit.

        ``n`` enters frame "foo_1" and passes through a single-input merge.
        The switch's false output [0] then goes to Exit, so the result is
        ``n`` itself.
        """
        with self.test_session():
            pred_false = tf.convert_to_tensor(False)
            n = tf.constant(10)

            frame = "foo_1"
            entered_pred = control_flow_ops.enter(pred_false, frame, False)
            entered_n = control_flow_ops.enter(n, frame, False)

            merged_n = control_flow_ops.merge([entered_n], name="merge_n")[0]
            false_side = control_flow_ops.switch(merged_n, entered_pred)[0]
            result = control_flow_ops.exit(false_side).eval()
        self.assertAllEqual(10, result)
开发者ID:peace195,项目名称:tensorflow,代码行数:14,代码来源:control_flow_ops_py_test.py


示例13: apply_with_random_selector

def apply_with_random_selector(x, func, num_cases):
  """Computes func(x, sel), with sel sampled from [0...num_cases-1].

  One branch is built for each case in ``range(num_cases)``. ``switch``
  routes the real ``x`` only into the branch whose case equals the sampled
  selector, and ``merge`` forwards that branch's output.

  Args:
    x: input Tensor.
    func: Python function to apply.
    num_cases: Python int32, number of cases to sample sel from.

  Returns:
    The result of func(x, sel), where func receives the value of the
    selector as a python integer, but sel is sampled dynamically.
  """
  sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
  # Pass the real x only to one of the func calls.
  return control_flow_ops.merge([
      func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
      for case in range(num_cases)
  ])[0]
开发者ID:veyvin,项目名称:tensorflow-learn,代码行数:16,代码来源:image_pre_test.py


示例14: apply_with_random_selector

def apply_with_random_selector(x, func, num_cases):
  """Computes func(x, sel), with sel sampled from [0...num_cases-1].

  A scalar int32 selector is sampled at run time. One branch is built per
  case; ``switch`` forwards the real ``x`` only into the branch whose case
  matches the selector, and ``merge`` returns that branch's output.

  TODO(coreylynch): add as a dependency, when slim or tensorflow/models are
  pipfied.
  Source:
  https://raw.githubusercontent.com/tensorflow/models/a9d0e6e8923a4/slim/preprocessing/inception_preprocessing.py

  Args:
    x: input Tensor.
    func: Python function to apply.
    num_cases: Python int32, number of cases to sample sel from.
  Returns:
    The result of func(x, sel), where func receives the value of the
    selector as a python integer, but sel is sampled dynamically.
  """
  sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
  outputs = []
  for case in range(num_cases):
    # Pass the real x only to one of the func calls.
    gated_x = control_flow_ops.switch(x, tf.equal(sel, case))[1]
    outputs.append(func(gated_x, case))
  return control_flow_ops.merge(outputs)[0]
开发者ID:NoPointExc,项目名称:models,代码行数:21,代码来源:preprocessing.py


示例15: create_op


#.........这里部分代码省略.........
    if self._return_as_is or op_type in _PASS_THROUGH_OPS:
      return self._wrap(super(ImperativeGraph, self).create_op(*args, **kwargs))

    if not output_dtypes:
      return self._wrap(
          super(ImperativeGraph, self).create_op(*args, **kwargs))

    output_has_ref = any([dtype._is_ref_dtype for dtype in output_dtypes])  # pylint: disable=protected-access

    if output_has_ref:
      if op_type not in _REF_OPS_WHITELIST:
        raise errors.UnimplementedError(None, None,
                                        op_type + ' op not supported in '
                                        'imperative graph')

      ret = super(ImperativeGraph, self).create_op(*args, **kwargs)

      if self._in_variable_creation:
        if op_type == 'Assign':
          self.add_pending_init(ret)

      return self._wrap(ret)

    with self.return_as_is():
      # Declares the variables to hold the output values of this op.
      op_output_var = [state_ops.variable_op_v2(
          tensor_shape.TensorShape(None), dtype, container=self._name)
                       for dtype in output_dtypes]
      # Ops to free the resources used by the temporary cache variables.
      # The following two ops are created for each cache variable,
      # having no control dependencies on any other ops :
      # var_handle_op ----> destroy_resource_op
      for dtype, v in zip(output_dtypes, op_output_var):
        with ops.control_dependencies(None):
          self._variable_cleanup_ops += [
              gen_resource_variable_ops.destroy_resource_op(
                  gen_resource_variable_ops.var_handle_op(
                      dtype, tensor_shape.TensorShape(None),
                      container=self._name, shared_name=v.op.name),
                  ignore_lookup_error=True)]

      # Create the conditional to run the original op only when the variable
      # corresponding to the first output is not initialized.
      inited = state_ops.is_variable_initialized(op_output_var[0])
      v_f, v_t = control_flow_ops.ref_switch(op_output_var[0], inited)
      # pylint: disable=protected-access
      v_f_op = gen_array_ops._ref_identity(v_f)
      v_t_op = gen_array_ops._ref_identity(v_t)
      # pylint: enable=protected-access

      with ops.control_dependencies([v_f_op.op]):
        # Create the original op
        orig_op = self._wrap(
            super(ImperativeGraph, self).create_op(*args, **kwargs))
      shapes = [val.get_shape() for val in orig_op.outputs]

      controls = []
      for var, val in zip(op_output_var, orig_op.outputs):
        if (not val.get_shape().is_fully_defined() or
            val.get_shape().num_elements() > 0):
          assign_op = state_ops.assign(var, val, validate_shape=False)
          assign_op.set_shape(val.get_shape())
          controls.append(assign_op)

      values = []
      if len(controls) > 1:
        if control_flow_ops.IsSwitch(orig_op):
          # pylint: disable=protected-access
          controls = gen_control_flow_ops._ref_merge(controls)
          # pylint: enable=protected-access
        else:
          controls = control_flow_ops.tuple(controls)

      for var, val in zip(op_output_var, orig_op.outputs):
        with ops.control_dependencies(controls):
          with self.colocate_with(v_f_op):
            real_val = array_ops.identity(val)
        with ops.control_dependencies([v_t_op.op]):
          with self.colocate_with(v_t_op):
            stored_val = array_ops.identity(var)
          stored_val.set_shape(val.get_shape())
          real_val, _ = control_flow_ops.merge([real_val, stored_val])
        real_val.op.node_def.attr['_gradient_op_type'].CopyFrom(
            attr_value_pb2.AttrValue(s=compat.as_bytes(self._merge_op_type)))
        values.append(real_val)

      for i, _ in enumerate(shapes):
        values[i].set_shape(shapes[i])
      self._outputs_map[orig_op.name] = values
      try:
        self._gradient_function_map[orig_op.name] = ops.get_gradient_function(
            orig_op)
      except (KeyError, LookupError):
        pass
      else:
        orig_op.node_def.attr['_gradient_op_type'].CopyFrom(
            attr_value_pb2.AttrValue(
                s=compat.as_bytes(self._imperative_op_type)))

      return MultiOutputOperation(values, orig_op)
开发者ID:chdinh,项目名称:tensorflow,代码行数:101,代码来源:imperative_graph.py



注:本文中的tensorflow.python.ops.control_flow_ops.merge函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python control_flow_ops.no_op函数代码示例发布时间:2022-05-27
下一篇:
Python control_flow_ops.group函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap