• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python control_flow_ops.case函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.ops.control_flow_ops.case函数的典型用法代码示例。如果您正苦于以下问题:Python case函数的具体用法?Python case怎么用?Python case使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了case函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: _testReturnValues

  def _testReturnValues(self, fn_true, fn_false, expected_value_true,
                        expected_value_false, strict=False,
                        check_cond=True, feed_dict=None):
    """Checks the values produced by `cond` and a one-branch `case`.

    A boolean placeholder drives both `control_flow_ops.cond` and
    `control_flow_ops.case([(pred, fn_true)], fn_false)`. Both ops are then
    evaluated with the placeholder fed True and afterwards False.

    Args:
      fn_true: Branch callable taken when the predicate is True.
      fn_false: Branch callable taken when the predicate is False (the
        default branch of the `case`).
      expected_value_true: Expected (possibly nested) result for True.
      expected_value_false: Expected (possibly nested) result for False.
      strict: Forwarded as `strict` to both `cond` and `case`.
      check_cond: When False, the `case` output is evaluated but not compared.
        The `cond` output is always compared.
      feed_dict: Optional extra feeds merged into every `sess.run` call.
    """
    extra_feeds = {} if feed_dict is None else feed_dict

    pred = array_ops.placeholder(dtypes.bool)
    cond_out = control_flow_ops.cond(pred, fn_true, fn_false, strict=strict)
    case_out = control_flow_ops.case([(pred, fn_true)], fn_false,
                                     strict=strict)

    with self.test_session() as sess:
      variables.global_variables_initializer().run()
      for pred_value, expected in ((True, expected_value_true),
                                   (False, expected_value_false)):
        feeds = {pred: pred_value}
        # Caller-supplied feeds are merged last, like dict.update().
        feeds.update(extra_feeds)
        got_cond, got_case = sess.run([cond_out, case_out], feed_dict=feeds)
        self.assertAllEqualNested(got_cond, expected)
        if check_cond:
          self.assertAllEqualNested(got_case, expected)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:27,代码来源:control_flow_ops_test.py


示例2: _decode

  def _decode(self, image_buffer, image_format):
    """Decodes the image buffer.

    Args:
      image_buffer: The tensor representing the encoded image tensor.
      image_format: The image format for the image in `image_buffer`.
        'png'/'PNG' selects the PNG decoder, 'raw'/'RAW' the raw uint8
        decoder, and any other value falls back to the JPEG decoder.

    Returns:
      A decoded image. Its static shape is set to
      [None, None, self._channels], and it is reshaped to `self._shape`
      when that is not None.
    """
    def format_is(lower, upper):
      # True when `image_format` equals either spelling of the format name.
      return math_ops.logical_or(math_ops.equal(image_format, lower),
                                 math_ops.equal(image_format, upper))

    decoders = {
        format_is('png', 'PNG'):
            lambda: image_ops.decode_png(image_buffer, self._channels),
        format_is('raw', 'RAW'):
            lambda: parsing_ops.decode_raw(image_buffer, dtypes.uint8),
    }
    image = control_flow_ops.case(
        decoders,
        default=lambda: image_ops.decode_jpeg(image_buffer, self._channels),
        exclusive=True)

    image.set_shape([None, None, self._channels])
    if self._shape is not None:
      image = array_ops.reshape(image, self._shape)
    return image
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:29,代码来源:tfexample_decoder.py


示例3: _decode

  def _decode(self, image_buffer, image_format):
    """Decodes the image buffer.

    Args:
      image_buffer: The tensor representing the encoded image tensor.
      image_format: The image format for the image in `image_buffer`. If image
        format is `raw`, all images are expected to be in this format, otherwise
        this op can decode a mix of `jpg` and `png` formats.

    Returns:
      A tensor that represents decoded image of self._shape, or
      (?, ?, self._channels) if self._shape is not specified.
    """
    is_raw = math_ops.logical_or(
        math_ops.equal(image_format, 'raw'),
        math_ops.equal(image_format, 'RAW'))

    def _decode_raw():
      """Decodes the raw bytes into values of `self._dtype`."""
      return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

    def _decode_encoded():
      """Lets `decode_image` choose png or jpg from the headers."""
      return image_ops.decode_image(image_buffer, self._channels)

    image = control_flow_ops.case({is_raw: _decode_raw},
                                  default=_decode_encoded, exclusive=True)
    image.set_shape([None, None, self._channels])
    if self._shape is None:
      return image
    return array_ops.reshape(image, self._shape)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:34,代码来源:tfexample_decoder.py


示例4: rot90

def rot90(image, k=1, name=None):
  """Rotate an image counter-clockwise by 90 degrees, `k` times.

  Args:
    image: A 3-D tensor of shape `[height, width, channels]`.
    k: A scalar integer. The number of times the image is rotated by 90 degrees.
    name: A name for this operation (optional).

  Returns:
    A rotated 3-D tensor of the same type as `image`. Height and width are
    statically unknown; the channel dimension is kept.
  """
  with ops.name_scope(name, 'rot90', [image, k]) as scope:
    image = ops.convert_to_tensor(image, name='image')
    _Check3DImage(image, require_static=False)
    k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k')
    k.get_shape().assert_has_rank(0)
    # Only k mod 4 matters: four quarter-turns bring the image back.
    k = math_ops.mod(k, 4)

    quarter_turns = (
        # 1 turn: reverse axis 1, then swap axes 0 and 1.
        lambda: array_ops.transpose(array_ops.reverse_v2(image, [1]),
                                    [1, 0, 2]),
        # 2 turns: reverse axes 0 and 1.
        lambda: array_ops.reverse_v2(image, [0, 1]),
        # 3 turns: swap axes 0 and 1, then reverse axis 1.
        lambda: array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]),
                                     [1]),
    )
    cases = [(math_ops.equal(k, turns), rotate)
             for turns, rotate in enumerate(quarter_turns, 1)]

    # k == 0 matches no predicate, so the default returns the input as-is.
    rotated = control_flow_ops.case(cases, default=lambda: image,
                                    exclusive=True, name=scope)
    rotated.set_shape([None, None, image.get_shape()[2]])
    return rotated
开发者ID:kdavis-mozilla,项目名称:tensorflow,代码行数:34,代码来源:image_ops_impl.py


示例5: test_inv_update_thunks

  def test_inv_update_thunks(self):
    """Ensures inverse update ops run once per global_step.

    A `case` op runs, at global step i, only the i-th inverse update thunk.
    After each step, exactly one more inverse matrix must differ from its
    initial value.
    """
    with self._graph.as_default(), self.test_session() as sess:
      fisher_estimator = estimator.FisherEstimator(
          damping_fn=lambda: 0.2,
          variables=[self.weights],
          layer_collection=self.layer_collection,
          cov_ema_decay=0.0)

      # Construct op that updates one inverse per global step.
      global_step = training_util.get_or_create_global_step()
      inv_matrices = [
          matrix
          for fisher_factor in self.layer_collection.get_factors()
          for matrix in fisher_factor._inverses_by_damping.values()
      ]
      inv_update_op_thunks = fisher_estimator.inv_update_thunks
      inv_update_op = control_flow_ops.case(
          [(math_ops.equal(global_step, i), thunk)
           for i, thunk in enumerate(inv_update_op_thunks)])
      increment_global_step = global_step.assign_add(1)

      sess.run(variables.global_variables_initializer())
      initial_inv_values = sess.run(inv_matrices)

      # Ensure there's one update per inverse matrix. This is true as long as
      # there's no fan-in/fan-out or parameter re-use.
      self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))

      # The test is a no-op with only 1 inverse matrix. Use the unittest
      # assertion: a bare `assert` is stripped under `python -O`.
      self.assertGreater(len(inv_matrices), 1)

      # Assign each covariance matrix a value other than the identity. This
      # ensures that the inverse matrices are updated to something different as
      # well.
      cov_matrices = [
          fisher_factor.get_cov()
          for fisher_factor in self.layer_collection.get_factors()
      ]
      sess.run([
          cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
          for cov_matrix in cov_matrices
      ])

      for i in range(len(inv_matrices)):
        # Compare new and old inverse values
        new_inv_values = sess.run(inv_matrices)
        is_inv_equal = [
            np.allclose(initial_inv_value, new_inv_value)
            for (initial_inv_value,
                 new_inv_value) in zip(initial_inv_values, new_inv_values)
        ]
        num_inv_equal = sum(is_inv_equal)

        # Ensure exactly one inverse matrix changes per step: after i steps,
        # len(inv_matrices) - i matrices still equal their initial values.
        self.assertEqual(num_inv_equal, len(inv_matrices) - i)

        # Run the single inverse update op selected by the current global
        # step, then advance the step.
        sess.run(inv_update_op)
        sess.run(increment_global_step)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:60,代码来源:estimator_test.py


示例6: testCase_dict

 def testCase_dict(self):
   """`case` also accepts a dict mapping predicates to branch callables."""
   x = constant_op.constant(2)
   x_is_1 = math_ops.equal(x, 1)
   x_is_2 = math_ops.equal(x, 2)
   branches = {x_is_1: lambda: constant_op.constant(2),
               x_is_2: lambda: constant_op.constant(4)}
   result = control_flow_ops.case(branches, exclusive=True)
   self.assertEqual(4, self.evaluate(result))
开发者ID:bunbutter,项目名称:tensorflow,代码行数:8,代码来源:control_flow_ops_test.py


示例7: testCase_withoutDefault_oneCondition

 def testCase_withoutDefault_oneCondition(self):
   """With one predicate and no default, a non-match fails at run time."""
   x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
   output = control_flow_ops.case(
       [(math_ops.equal(x, 1), lambda: constant_op.constant(2))],
       exclusive=True)
   with self.test_session() as sess:
     matched = sess.run(output, feed_dict={x: 1})
     self.assertEqual(matched, 2)
     with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
       sess.run(output, feed_dict={x: 4})
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:8,代码来源:control_flow_ops_test.py


示例8: piecewise_constant

def piecewise_constant(x, boundaries, values, name=None):
  """Piecewise constant from boundaries and interval values.

  Example: use a learning rate that's 1.0 for the first 100000 steps, 0.5
    for steps 100001 to 110000, and 0.1 for any additional steps.

  ```python
  global_step = tf.Variable(0, trainable=False)
  boundaries = [100000, 110000]
  values = [1.0, 0.5, 0.1]
  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

  # Later, whenever we perform an optimization step, we increment global_step.
  ```

  Args:
    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
    boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
      increasing entries, and with all elements having the same type as `x`.
    values: A list of `Tensor`s or `float`s or `int`s that specifies the values
      for the intervals defined by `boundaries`. It should have one more element
      than `boundaries`, and all elements should have the same type.
    name: A string. Optional name of the operation. Defaults to
      'PiecewiseConstant'.

  Returns:
    A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
    `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
    and values[-1] when `x > boundaries[-1]`.

  Raises:
    ValueError: if any element of `boundaries` does not have the dtype of
      `x`, if the elements of `values` do not share one dtype, or if `values`
      does not have exactly one more element than `boundaries`.
  """

  with ops.name_scope(name, 'PiecewiseConstant',
                      [x, boundaries, values, name]) as name:
    x = ops.convert_to_tensor(x)
    # Avoid explicit conversion to x's dtype. This could result in faulty
    # comparisons, for example if floats are converted to integers.
    boundaries = ops.convert_n_to_tensor(boundaries)
    if not all(b.dtype == x.dtype for b in boundaries):
      raise ValueError('boundaries must have the same dtype as x.')
    # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
    values = ops.convert_n_to_tensor(values)
    if not all(v.dtype == values[0].dtype for v in values):
      raise ValueError('values must have elements all with the same dtype.')
    # Without this check a length mismatch is silently accepted: zip() below
    # truncates, so some values go unused or some intervals fall through to
    # `default`.
    if len(boundaries) != len(values) - 1:
      raise ValueError(
          'The length of boundaries should be 1 less than the length of '
          'values (got %d boundaries and %d values).' %
          (len(boundaries), len(values)))

    # A list (not a dict keyed by tensors) keeps predicate order, and thus
    # graph construction, deterministic.
    pred_fn_pairs = [(x <= boundaries[0], lambda: values[0]),
                     (x > boundaries[-1], lambda: values[-1])]
    for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
      # Need to bind v here; can do this with lambda v=v: ...
      pred = (x > low) & (x <= high)
      pred_fn_pairs.append((pred, lambda v=v: v))

    # The default isn't needed here because our conditions are mutually
    # exclusive and exhaustive, but tf.case requires it.
    default = lambda: values[0]
    return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
开发者ID:AriaAsuka,项目名称:tensorflow,代码行数:57,代码来源:learning_rate_decay.py


示例9: testCase_withDefault

 def testCase_withDefault(self):
   """Inputs matching no predicate fall through to the `default` branch."""
   x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
   branches = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
               (math_ops.equal(x, 2), lambda: constant_op.constant(4))]
   fallback = lambda: constant_op.constant(6)
   output = control_flow_ops.case(branches, fallback, exclusive=True)
   with self.test_session() as sess:
     for fed, expected in ((1, 2), (2, 4), (3, 6)):
       self.assertEqual(sess.run(output, feed_dict={x: fed}), expected)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:10,代码来源:control_flow_ops_test.py


示例10: _decode

  def _decode(self, image_buffer, image_format):
    """Decodes the image buffer.

    Args:
      image_buffer: The tensor representing the encoded image tensor.
      image_format: The image format for the image in `image_buffer`. If image
        format is `raw`, all images are expected to be in this format, otherwise
        this op can decode a mix of `jpg` and `png` formats.

    Returns:
      A tensor that represents decoded image of self._shape, or
      (?, ?, self._channels) if self._shape is not specified.
    """

    def to_dtype(decoded):
      # Both header-based decoders cast their output to self._dtype.
      return math_ops.cast(decoded, self._dtype)

    def decode_image():
      """Decodes an image, choosing the decoder from its headers."""
      return to_dtype(
          image_ops.decode_image(image_buffer, channels=self._channels))

    def decode_jpeg():
      """Decodes a jpeg image with the configured '_dct_method'."""
      return to_dtype(
          image_ops.decode_jpeg(
              image_buffer,
              channels=self._channels,
              dct_method=self._dct_method))

    def check_jpeg():
      """Sends JPEG data to decode_jpeg and everything else to decode_image."""
      # JPEGs go through decode_jpeg directly so 'dct_method' can be passed.
      return control_flow_ops.cond(
          image_ops.is_jpeg(image_buffer),
          decode_jpeg,
          decode_image,
          name='cond_jpeg')

    is_raw = math_ops.logical_or(
        math_ops.equal(image_format, 'raw'),
        math_ops.equal(image_format, 'RAW'))
    image = control_flow_ops.case(
        {is_raw: lambda: parsing_ops.decode_raw(image_buffer,
                                                out_type=self._dtype)},
        default=check_jpeg, exclusive=True)

    image.set_shape([None, None, self._channels])
    if self._shape is not None:
      image = array_ops.reshape(image, self._shape)
    return image
开发者ID:Albert-Z-Guo,项目名称:tensorflow,代码行数:55,代码来源:tfexample_decoder.py


示例11: test_singleton_strict

  def test_singleton_strict(self):
    """With strict=True, mismatched singleton branch structures are rejected.

    Each pair of branches below is passed to `cond` and then to `case`. The
    expected exception is ValueError for tensor vs. one-element list, and
    TypeError for one-element list vs. one-element tuple.
    """
    make_tensor = lambda: constant_op.constant(1)
    make_list = lambda: [constant_op.constant(2)]
    make_tuple = lambda: (constant_op.constant(3),)

    mismatches = [(ValueError, make_tensor, make_list),
                  (TypeError, make_list, make_tuple)]

    for expected_error, first_fn, second_fn in mismatches:
      with self.assertRaises(expected_error):
        control_flow_ops.cond(constant_op.constant(True), first_fn, second_fn,
                              strict=True)

    for expected_error, first_fn, second_fn in mismatches:
      with self.assertRaises(expected_error):
        control_flow_ops.case([(constant_op.constant(True), first_fn)],
                              second_fn, strict=True)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:20,代码来源:control_flow_ops_test.py


示例12: _decode

  def _decode(self, image_buffer, image_format):
    """Decodes the image buffer.

    Args:
      image_buffer: The tensor representing the encoded image tensor.
      image_format: The image format for the image in `image_buffer`.

    Returns:
      A tensor that represents decoded image of self._shape, or
      (?, ?, self._channels) if self._shape is not specified.

    Raises:
      ValueError: from the JPEG decoder function if it is invoked while
        `self._dtype` is not uint8. It is the default branch whenever
        `self._channels <= 3`.
    """
    def format_is(lower, upper):
      # Matches either spelling of a format name, e.g. 'png' or 'PNG'.
      return math_ops.logical_or(math_ops.equal(image_format, lower),
                                 math_ops.equal(image_format, upper))

    def decode_png():
      return image_ops.decode_png(
          image_buffer, self._channels, dtype=self._dtype)

    def decode_raw():
      return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

    def decode_jpg():
      # NOTE(review): if `case` invokes every branch function while building
      # the graph, this raises for non-uint8 dtypes even when no JPEG is ever
      # fed -- confirm.
      if self._dtype != dtypes.uint8:
        raise ValueError(
            'jpeg decoder can only be used to decode to tf.uint8 but %s was '
            'requested for a jpeg image.' % self._dtype)
      return image_ops.decode_jpeg(image_buffer, self._channels)

    pred_fn_pairs = {}
    if self._channels > 3:
      # JPEG is not a valid decoder for images with more than 3 channels
      # (e.g. RGBA), so PNG is the fallback there.
      default_decoder = decode_png
    else:
      pred_fn_pairs[format_is('png', 'PNG')] = decode_png
      default_decoder = decode_jpg
    pred_fn_pairs[format_is('raw', 'RAW')] = decode_raw

    image = control_flow_ops.case(
        pred_fn_pairs, default=default_decoder, exclusive=True)

    image.set_shape([None, None, self._channels])
    if self._shape is not None:
      image = array_ops.reshape(image, self._shape)
    return image
开发者ID:LUTAN,项目名称:tensorflow,代码行数:53,代码来源:tfexample_decoder.py


示例13: testCase_multiple_matches_exclusive

 def testCase_multiple_matches_exclusive(self):
   """With exclusive=True, more than one true predicate fails at run time."""
   x = array_ops.placeholder(dtype=dtypes.int32, shape=[])

   def const_fn(value):
     return lambda: constant_op.constant(value)

   conditions = [(math_ops.equal(x, 1), const_fn(2)),
                 (math_ops.equal(x, 2), const_fn(4)),
                 (math_ops.equal(x, 2), const_fn(6))]
   output = control_flow_ops.case(conditions, const_fn(8), exclusive=True)
   with self.test_session() as sess:
     for fed, expected in ((1, 2), (3, 8)):
       self.assertEqual(sess.run(output, feed_dict={x: fed}), expected)
     # x == 2 satisfies two predicates at once.
     with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
       sess.run(output, feed_dict={x: 2})
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:12,代码来源:control_flow_ops_test.py


示例14: _testShape

  def _testShape(self, fn_true, fn_false, expected_shape,
                 strict=False):
    """Checks the static output shape of `cond` and of a one-branch `case`.

    Args:
      fn_true: Branch callable for a True predicate.
      fn_false: Branch callable for a False predicate (the `case` default).
      expected_shape: Expected (possibly nested) shape of either output.
      strict: Forwarded as `strict` to both `cond` and `case`.
    """
    pred = array_ops.placeholder(dtypes.bool)

    def assert_output_shape(output):
      self.assertEqual(_RawNestedShape(_GetNestedShape(output)),
                       _RawNestedShape(expected_shape))

    assert_output_shape(
        control_flow_ops.cond(pred, fn_true, fn_false, strict=strict))
    assert_output_shape(
        control_flow_ops.case([(pred, fn_true)], fn_false, strict=strict))
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:12,代码来源:control_flow_ops_test.py


示例15: test_cov_update_thunks

  def test_cov_update_thunks(self):
    """Ensures covariance update ops run once per global_step.

    A `case` op runs, at global step i, only the i-th covariance update
    thunk. After each step, exactly one more covariance matrix must differ
    from its initial value.
    """
    with self._graph.as_default(), self.test_session() as sess:
      fisher_estimator = estimator.FisherEstimatorRoundRobin(
          variables=[self.weights],
          layer_collection=self.layer_collection,
          damping=0.2,
          cov_ema_decay=0.0)

      # Construct an op that executes one covariance update per step.
      global_step = training_util.get_or_create_global_step()
      (cov_variable_thunks, cov_update_op_thunks, _,
       _) = fisher_estimator.create_ops_and_vars_thunks()
      for thunk in cov_variable_thunks:
        thunk()
      cov_matrices = [
          fisher_factor.get_cov()
          for fisher_factor in self.layer_collection.get_factors()
      ]
      cov_update_op = control_flow_ops.case(
          [(math_ops.equal(global_step, i), thunk)
           for i, thunk in enumerate(cov_update_op_thunks)])
      increment_global_step = global_step.assign_add(1)

      sess.run(variables.global_variables_initializer())
      initial_cov_values = sess.run(cov_matrices)

      # Ensure there's one update per covariance matrix.
      self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))

      # The test is a no-op with only 1 covariance matrix. Use the unittest
      # assertion: a bare `assert` is stripped under `python -O`.
      self.assertGreater(len(cov_matrices), 1)

      for i in range(len(cov_matrices)):
        # Compare new and old covariance values
        new_cov_values = sess.run(cov_matrices)
        is_cov_equal = [
            np.allclose(initial_cov_value, new_cov_value)
            for (initial_cov_value,
                 new_cov_value) in zip(initial_cov_values, new_cov_values)
        ]
        num_cov_equal = sum(is_cov_equal)

        # Ensure exactly one covariance matrix changes per step: after i
        # steps, len(cov_matrices) - i matrices still equal their initial
        # values.
        self.assertEqual(num_cov_equal, len(cov_matrices) - i)

        # Run the single covariance update op selected by the current global
        # step, then advance the step.
        sess.run(cov_update_op)
        sess.run(increment_global_step)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:49,代码来源:estimator_test.py


示例16: control_map_fn

    def control_map_fn(x, y):
      """Floor-divides `x` by 2 when `y` is 2 or 3, else multiplies it by 2."""
      y_is_2_or_3 = math_ops.logical_or(math_ops.equal(y, 2),
                                        math_ops.equal(y, 3))
      return control_flow_ops.case({y_is_2_or_3: lambda: x // 2},
                                   default=lambda: x * 2, exclusive=True)
开发者ID:bunbutter,项目名称:tensorflow,代码行数:15,代码来源:map_test.py


示例17: testCase_withoutDefault

 def testCase_withoutDefault(self):
   """Without a default, an input matching no predicate fails at run time."""
   x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
   expected = ((1, 2), (2, 4), (3, 6))
   conditions = []
   for key, out in expected:
     # Bind `out` as a default; a bare closure would see only the last value.
     conditions.append(
         (math_ops.equal(x, key), lambda out=out: constant_op.constant(out)))
   output = control_flow_ops.case(conditions, exclusive=True)
   with self.test_session() as sess:
     for key, out in expected:
       self.assertEqual(sess.run(output, feed_dict={x: key}), out)
     with self.assertRaisesRegexp(
         errors.InvalidArgumentError,
         r"\[None of the conditions evaluated as True. "
         r"Conditions: \(Equal:0, Equal_1:0, Equal_2:0\), Values:\] "
         r"\[0 0 0\]"):
       sess.run(output, feed_dict={x: 4})
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:16,代码来源:control_flow_ops_test.py


示例18: conditional_decoding

                def conditional_decoding(keys_to_tensors):
                    """See base class.

                    Decodes 'image/encoded' as PNG when 'image/format' equals
                    'png', otherwise as JPEG, requesting 3 channels either way,
                    then reshapes the result to `image_shape`.
                    """
                    encoded = keys_to_tensors['image/encoded']
                    fmt = keys_to_tensors['image/format']
                    decoded = control_flow_ops.case(
                        {math_ops.equal(fmt, 'png'):
                             lambda: image_ops.decode_png(encoded, 3)},
                        default=lambda: image_ops.decode_jpeg(encoded, 3),
                        exclusive=True)
                    # `image_shape` is a free variable, defined outside
                    # this function.
                    return array_ops.reshape(decoded, image_shape)
开发者ID:AlexMikhalev,项目名称:polyaxon,代码行数:17,代码来源:test_data_decoder.py


示例19: decayed_lr

  def decayed_lr(x, boundaries, values, name):
    """Helper to recompute learning rate; most helpful in eager-mode.

    Args:
      x: Value compared against `boundaries`.
      boundaries: Sequence of boundaries. Each must have the dtype of `x`,
        except that int32 boundaries are cast when `x` is int64.
      values: Sequence with exactly one more element than `boundaries`; all
        elements must share one dtype.
      name: Optional name scope; defaults to "PiecewiseConstant".

    Returns:
      The `control_flow_ops.case` result. It is `values[0]` when
      `x <= boundaries[0]`, `values[-1]` when `x > boundaries[-1]`, and
      `values[i]` when `boundaries[i-1] < x <= boundaries[i]`.

    Raises:
      ValueError: on a boundary/`x` dtype mismatch that cannot be promoted,
        on mixed dtypes in `values`, or when `len(values)` is not
        `len(boundaries) + 1`.
    """
    with ops.name_scope(name, "PiecewiseConstant",
                        [x, boundaries, values, name]) as name:
      boundaries = ops.convert_n_to_tensor(boundaries)
      values = ops.convert_n_to_tensor(values)
      x_recomp = ops.convert_to_tensor(x)
      # Avoid explicit conversion to x's dtype. This could result in faulty
      # comparisons, for example if floats are converted to integers.
      for i, b in enumerate(boundaries):
        if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
          # We can promote int32 boundaries to int64 without loss of precision.
          # This covers the most common case where the user passes in boundaries
          # as an array of Python integers.
          if (b.dtype.base_dtype == dtypes.int32 and
              x_recomp.dtype.base_dtype == dtypes.int64):
            b = math_ops.cast(b, x_recomp.dtype.base_dtype)
            boundaries[i] = b
          else:
            raise ValueError(
                "Boundaries (%s) must have the same dtype as x (%s)." %
                (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
      # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
      for v in values[1:]:
        if v.dtype.base_dtype != values[0].dtype.base_dtype:
          raise ValueError(
              "Values must have elements all with the same dtype (%s vs %s)." %
              (values[0].dtype.base_dtype, v.dtype.base_dtype))
      # Without this check a length mismatch is silently accepted: zip() below
      # truncates, so some values go unused or some intervals fall through to
      # `default`.
      if len(boundaries) != len(values) - 1:
        raise ValueError(
            "The length of boundaries should be 1 less than the length of "
            "values (got %d boundaries and %d values)." %
            (len(boundaries), len(values)))
      pred_fn_pairs = []
      pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
      pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
      for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
        # Need to bind v here; can do this with lambda v=v: ...
        pred = (x_recomp > low) & (x_recomp <= high)
        pred_fn_pairs.append((pred, lambda v=v: v))

      # The default isn't needed here because our conditions are mutually
      # exclusive and exhaustive, but tf.case requires it.
      default = lambda: values[0]
      return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:40,代码来源:learning_rate_decay_v2.py


示例20: _decode

  def _decode(self, image_buffer, image_format):
    """Decodes the image buffer.

    Args:
      image_buffer: The tensor representing the encoded image tensor.
      image_format: The image format for the image in `image_buffer`. Both
        'png' and 'PNG' select the PNG decoder; any other value falls back to
        the JPEG decoder.

    Returns:
      A decoded image, reshaped to `self._shape` when that is not None.
    """
    def decode_png():
      return image_ops.decode_png(image_buffer, self._channels)
    def decode_jpg():
      return image_ops.decode_jpeg(image_buffer, self._channels)

    # Accept either spelling of the format name, as the other variants of
    # this decoder do. Otherwise upper-case 'PNG' records would take the JPEG
    # default branch.
    is_png = math_ops.logical_or(math_ops.equal(image_format, 'png'),
                                 math_ops.equal(image_format, 'PNG'))
    image = control_flow_ops.case({
        is_png: decode_png,
    }, default=decode_jpg, exclusive=True)

    if self._shape is not None:
      image = array_ops.reshape(image, self._shape)
    return image
开发者ID:285219011,项目名称:hello-world,代码行数:22,代码来源:tfexample_decoder.py



注:本文中的tensorflow.python.ops.control_flow_ops.case函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python control_flow_ops.cond函数代码示例发布时间:2022-05-27
下一篇:
Python constant_op.constant函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap