本文整理汇总了Python中tensorflow.python.ops.control_flow_ops.case函数的典型用法代码示例。如果您正苦于以下问题:Python case函数的具体用法?Python case怎么用?Python case使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了case函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _testReturnValues
def _testReturnValues(self, fn_true, fn_false, expected_value_true,
                      expected_value_false, strict=False,
                      check_cond=True, feed_dict=None):
  """Checks that `cond` and a one-branch `case` yield the expected values.

  Both ops share one boolean placeholder and the same branch functions. They
  are evaluated with the placeholder fed True and then False.

  Args:
    fn_true: Branch function taken when the predicate is True.
    fn_false: Branch function taken when the predicate is False; it is also
      the `default` of the `case` op.
    expected_value_true: Expected (possibly nested) result for True.
    expected_value_false: Expected (possibly nested) result for False.
    strict: Forwarded as `strict` to both `cond` and `case`.
    check_cond: When False, the `case` result is still evaluated but not
      compared. The `cond` result is always compared. (Despite its name, this
      flag gates the `case` assertion.)
    feed_dict: Optional extra feeds merged into every `sess.run` call.
  """
  extra_feeds = {} if feed_dict is None else feed_dict
  predicate = array_ops.placeholder(dtypes.bool)
  cond_out = control_flow_ops.cond(predicate, fn_true, fn_false,
                                   strict=strict)
  case_out = control_flow_ops.case([(predicate, fn_true)], fn_false,
                                   strict=strict)
  with self.test_session() as sess:
    variables.global_variables_initializer().run()
    for pred_value, expected in ((True, expected_value_true),
                                 (False, expected_value_false)):
      feeds = {predicate: pred_value}
      feeds.update(extra_feeds)
      got_cond, got_case = sess.run([cond_out, case_out], feed_dict=feeds)
      self.assertAllEqualNested(got_cond, expected)
      if check_cond:
        self.assertAllEqualNested(got_case, expected)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:27,代码来源:control_flow_ops_test.py
示例2: _decode
def _decode(self, image_buffer, image_format):
  """Decodes the image buffer.

  Args:
    image_buffer: The tensor representing the encoded image.
    image_format: The image format for the image in `image_buffer`.
      'png'/'PNG' selects the PNG decoder and 'raw'/'RAW' selects the raw
      (uint8) decoder. Any other value falls back to the JPEG decoder.

  Returns:
    A decoded image, reshaped to `self._shape` when that is set.
  """
  channels = self._channels

  def format_is(lower, upper):
    # True when `image_format` equals either spelling of the format name.
    return math_ops.logical_or(math_ops.equal(image_format, lower),
                               math_ops.equal(image_format, upper))

  branches = {
      format_is('png', 'PNG'):
          lambda: image_ops.decode_png(image_buffer, channels),
      format_is('raw', 'RAW'):
          lambda: parsing_ops.decode_raw(image_buffer, dtypes.uint8),
  }
  decoded = control_flow_ops.case(
      branches,
      default=lambda: image_ops.decode_jpeg(image_buffer, channels),
      exclusive=True)
  decoded.set_shape([None, None, channels])
  if self._shape is None:
    return decoded
  return array_ops.reshape(decoded, self._shape)
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:29,代码来源:tfexample_decoder.py
示例3: _decode
def _decode(self, image_buffer, image_format):
  """Decodes the image buffer.

  Args:
    image_buffer: The tensor representing the encoded image.
    image_format: The image format for the image in `image_buffer`. If the
      format is `raw`, all images are expected to be in this format.
      Otherwise this op can decode a mix of `jpg` and `png` formats.

  Returns:
    A tensor holding the decoded image. It is reshaped to `self._shape` when
    that is set; otherwise its static shape is (?, ?, self._channels).
  """
  is_raw = math_ops.logical_or(
      math_ops.equal(image_format, 'raw'),
      math_ops.equal(image_format, 'RAW'))

  def from_raw():
    """Reinterprets the raw bytes as `self._dtype` values."""
    return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

  def from_header():
    """Lets `decode_image` choose the png/jpg decoder from the headers."""
    return image_ops.decode_image(image_buffer, self._channels)

  result = control_flow_ops.case(
      {is_raw: from_raw}, default=from_header, exclusive=True)
  result.set_shape([None, None, self._channels])
  if self._shape is not None:
    result = array_ops.reshape(result, self._shape)
  return result
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:34,代码来源:tfexample_decoder.py
示例4: rot90
def rot90(image, k=1, name=None):
  """Rotate an image counter-clockwise by 90 degrees.

  Args:
    image: A 3-D tensor of shape `[height, width, channels]`.
    k: A scalar integer. The number of times the image is rotated by 90
      degrees; it is reduced modulo 4.
    name: A name for this operation (optional).

  Returns:
    A rotated 3-D tensor of the same type as `image`. For odd `k` the height
    and width dimensions are swapped. The static channel dimension is kept.
  """
  with ops.name_scope(name, 'rot90', [image, k]) as scope:
    image = ops.convert_to_tensor(image, name='image')
    _Check3DImage(image, require_static=False)
    k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k')
    k.get_shape().assert_has_rank(0)
    quarter_turns = math_ops.mod(k, 4)

    def one_turn():
      # Flip the width axis, then swap height and width.
      return array_ops.transpose(array_ops.reverse_v2(image, [1]), [1, 0, 2])

    def two_turns():
      # Flip both spatial axes.
      return array_ops.reverse_v2(image, [0, 1])

    def three_turns():
      # Swap height and width, then flip the (new) width axis.
      return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), [1])

    branches = [
        (math_ops.equal(quarter_turns, turns), rotate)
        for turns, rotate in ((1, one_turn), (2, two_turns), (3, three_turns))
    ]
    rotated = control_flow_ops.case(branches, default=lambda: image,
                                    exclusive=True, name=scope)
    rotated.set_shape([None, None, image.get_shape()[2]])
    return rotated
开发者ID:kdavis-mozilla,项目名称:tensorflow,代码行数:34,代码来源:image_ops_impl.py
示例5: test_inv_update_thunks
def test_inv_update_thunks(self):
  """Ensures inverse update ops run once per global_step.

  Builds a `case` op whose i-th predicate is `global_step == i`, so each step
  runs only the matching inverse-update thunk. It then checks that each step
  changes exactly one more inverse matrix.
  """
  with self._graph.as_default(), self.test_session() as sess:
    fisher_estimator = estimator.FisherEstimator(
        damping_fn=lambda: 0.2,
        variables=[self.weights],
        layer_collection=self.layer_collection,
        cov_ema_decay=0.0)

    # Construct op that updates one inverse per global step.
    global_step = training_util.get_or_create_global_step()
    inv_matrices = [
        matrix
        for fisher_factor in self.layer_collection.get_factors()
        for matrix in fisher_factor._inverses_by_damping.values()
    ]
    inv_update_op_thunks = fisher_estimator.inv_update_thunks
    inv_update_op = control_flow_ops.case(
        [(math_ops.equal(global_step, i), thunk)
         for i, thunk in enumerate(inv_update_op_thunks)])
    increment_global_step = global_step.assign_add(1)

    sess.run(variables.global_variables_initializer())
    initial_inv_values = sess.run(inv_matrices)

    # Ensure there's one update per inverse matrix. This is true as long as
    # there's no fan-in/fan-out or parameter re-use.
    self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))

    # The test is vacuous with only one inverse matrix. Use a unittest
    # assertion rather than a bare `assert`, which `python -O` strips.
    self.assertGreater(len(inv_matrices), 1)

    # Assign each covariance matrix a value other than the identity. This
    # ensures that the inverse matrices are updated to something different as
    # well.
    cov_matrices = [
        fisher_factor.get_cov()
        for fisher_factor in self.layer_collection.get_factors()
    ]
    sess.run([
        cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
        for cov_matrix in cov_matrices
    ])

    for i in range(len(inv_matrices)):
      # Compare new and old inverse values.
      new_inv_values = sess.run(inv_matrices)
      num_inv_equal = sum(
          np.allclose(initial_inv_value, new_inv_value)
          for initial_inv_value, new_inv_value in zip(initial_inv_values,
                                                      new_inv_values))
      # After i steps exactly i inverse matrices have changed, i.e. exactly
      # one inverse matrix changes per step.
      self.assertEqual(num_inv_equal, len(inv_matrices) - i)

      # Run the single inverse update selected by the current global step,
      # then advance the step.
      sess.run(inv_update_op)
      sess.run(increment_global_step)
开发者ID:QiangCai,项目名称:tensorflow,代码行数:60,代码来源:estimator_test.py
示例6: testCase_dict
def testCase_dict(self):
  """`case` accepts a dict mapping predicates to branch functions."""
  two = constant_op.constant(2)
  branch_table = {
      math_ops.equal(two, 1): lambda: constant_op.constant(2),
      math_ops.equal(two, 2): lambda: constant_op.constant(4),
  }
  result = control_flow_ops.case(branch_table, exclusive=True)
  self.assertEqual(4, self.evaluate(result))
开发者ID:bunbutter,项目名称:tensorflow,代码行数:8,代码来源:control_flow_ops_test.py
示例7: testCase_withoutDefault_oneCondition
def testCase_withoutDefault_oneCondition(self):
  """A default-less, single-branch `case` fails when its predicate is false."""
  selector = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  only_branch = (math_ops.equal(selector, 1), lambda: constant_op.constant(2))
  output = control_flow_ops.case([only_branch], exclusive=True)
  with self.test_session() as sess:
    # The predicate holds, so the only branch is taken.
    self.assertEqual(sess.run(output, feed_dict={selector: 1}), 2)
    # No predicate holds and there is no default: running the op errors.
    with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
      sess.run(output, feed_dict={selector: 4})
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:8,代码来源:control_flow_ops_test.py
示例8: piecewise_constant
def piecewise_constant(x, boundaries, values, name=None):
  """Piecewise constant from boundaries and interval values.

  Example: use a learning rate that's 1.0 for the first 100000 steps, 0.5
  for steps 100001 to 110000, and 0.1 for any additional steps.

  ```python
  global_step = tf.Variable(0, trainable=False)
  boundaries = [100000, 110000]
  values = [1.0, 0.5, 0.1]
  learning_rate = tf.train.piecewise_constant(global_step, boundaries, values)

  # Later, whenever we perform an optimization step, we increment global_step.
  ```

  Args:
    x: A 0-D scalar `Tensor`. Must be one of the following types: `float32`,
      `float64`, `uint8`, `int8`, `int16`, `int32`, `int64`.
    boundaries: A list of `Tensor`s or `int`s or `float`s with strictly
      increasing entries, and with all elements having the same type as `x`.
    values: A list of `Tensor`s or `float`s or `int`s that specifies the
      values for the intervals defined by `boundaries`. It should have one
      more element than `boundaries`, and all elements should have the same
      type.
    name: A string. Optional name of the operation. Defaults to
      'PiecewiseConstant'.

  Returns:
    A 0-D Tensor. Its value is `values[0]` when `x <= boundaries[0]`,
    `values[1]` when `x > boundaries[0]` and `x <= boundaries[1]`, ...,
    and `values[-1]` when `x > boundaries[-1]`.

  Raises:
    ValueError: if any boundary's dtype differs from `x`'s, if the elements
      of `values` do not all share one dtype, or if `values` does not have
      exactly one more element than `boundaries`.
  """
  with ops.name_scope(name, 'PiecewiseConstant',
                      [x, boundaries, values, name]) as name:
    x = ops.convert_to_tensor(x)
    # Avoid explicit conversion to x's dtype. This could result in faulty
    # comparisons, for example if floats are converted to integers.
    boundaries = ops.convert_n_to_tensor(boundaries)
    if not all(b.dtype == x.dtype for b in boundaries):
      raise ValueError('boundaries must have the same dtype as x.')
    # TODO(rdipietro): Ensure that boundaries' elements are strictly increasing.
    values = ops.convert_n_to_tensor(values)
    if not all(v.dtype == values[0].dtype for v in values):
      raise ValueError('values must have elements all with the same dtype.')
    # Without this check a mismatched `values` list is accepted silently: the
    # zip below truncates, so some intervals get no predicate and those
    # inputs fall through to the default.
    if len(values) != len(boundaries) + 1:
      raise ValueError(
          'values must have exactly one more element than boundaries '
          '(got %d boundaries and %d values).' % (len(boundaries), len(values)))

    # Use a list of (predicate, fn) pairs rather than a dict keyed by
    # predicate Tensors. Tensors are not reliably hashable (e.g. under TF2
    # eager semantics), and a list keeps the branch order explicit.
    pred_fn_pairs = [
        (x <= boundaries[0], lambda: values[0]),
        (x > boundaries[-1], lambda: values[-1]),
    ]
    for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
      # Bind v as a default argument; a plain closure would only see the
      # last v of the loop.
      pred = (x > low) & (x <= high)
      pred_fn_pairs.append((pred, lambda v=v: v))

    # The default isn't needed here because our conditions are mutually
    # exclusive and exhaustive, but tf.case requires it.
    default = lambda: values[0]
    return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
开发者ID:AriaAsuka,项目名称:tensorflow,代码行数:57,代码来源:learning_rate_decay.py
示例9: testCase_withDefault
def testCase_withDefault(self):
  """A matching predicate selects its branch; otherwise the default runs."""
  selector = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  branches = [
      (math_ops.equal(selector, 1), lambda: constant_op.constant(2)),
      (math_ops.equal(selector, 2), lambda: constant_op.constant(4)),
  ]

  def fallback():
    return constant_op.constant(6)

  output = control_flow_ops.case(branches, fallback, exclusive=True)
  with self.test_session() as sess:
    for fed, expected in ((1, 2), (2, 4), (3, 6)):
      self.assertEqual(sess.run(output, feed_dict={selector: fed}), expected)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:10,代码来源:control_flow_ops_test.py
示例10: _decode
def _decode(self, image_buffer, image_format):
  """Decodes the image buffer.

  Args:
    image_buffer: The tensor representing the encoded image.
    image_format: The image format for the image in `image_buffer`. If the
      format is `raw`, all images are expected to be in this format.
      Otherwise this op can decode a mix of `jpg` and `png` formats.

  Returns:
    A tensor holding the decoded image. It is reshaped to `self._shape` when
    that is set; otherwise its static shape is (?, ?, self._channels).
  """
  def decode_raw():
    """Reinterprets the raw bytes as `self._dtype` values."""
    return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

  def decode_by_header():
    """Decodes jpeg with `decode_jpeg` and anything else with `decode_image`."""
    # Jpeg goes through decode_jpeg directly (not decode_image) so that the
    # jpeg-specific 'dct_method' parameter can be passed.
    def via_jpeg():
      return math_ops.cast(
          image_ops.decode_jpeg(image_buffer,
                                channels=self._channels,
                                dct_method=self._dct_method),
          self._dtype)

    def via_generic():
      return math_ops.cast(
          image_ops.decode_image(image_buffer, channels=self._channels),
          self._dtype)

    return control_flow_ops.cond(image_ops.is_jpeg(image_buffer),
                                 via_jpeg,
                                 via_generic,
                                 name='cond_jpeg')

  raw_requested = math_ops.logical_or(
      math_ops.equal(image_format, 'raw'),
      math_ops.equal(image_format, 'RAW'))
  image = control_flow_ops.case({raw_requested: decode_raw},
                                default=decode_by_header,
                                exclusive=True)
  image.set_shape([None, None, self._channels])
  if self._shape is None:
    return image
  return array_ops.reshape(image, self._shape)
开发者ID:Albert-Z-Guo,项目名称:tensorflow,代码行数:55,代码来源:tfexample_decoder.py
示例11: test_singleton_strict
def test_singleton_strict(self):
  """With strict=True, mismatched singleton branch structures are rejected."""
  def as_tensor():
    return constant_op.constant(1)

  def as_list():
    return [constant_op.constant(2)]

  def as_tuple():
    return (constant_op.constant(3),)

  # Each entry: (expected exception, thunk that builds the offending op).
  expectations = (
      (ValueError, lambda: control_flow_ops.cond(
          constant_op.constant(True), as_tensor, as_list, strict=True)),
      (TypeError, lambda: control_flow_ops.cond(
          constant_op.constant(True), as_list, as_tuple, strict=True)),
      (ValueError, lambda: control_flow_ops.case(
          [(constant_op.constant(True), as_tensor)], as_list, strict=True)),
      (TypeError, lambda: control_flow_ops.case(
          [(constant_op.constant(True), as_list)], as_tuple, strict=True)),
  )
  for expected_error, build in expectations:
    with self.assertRaises(expected_error):
      build()
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:20,代码来源:control_flow_ops_test.py
示例12: _decode
def _decode(self, image_buffer, image_format):
  """Decodes the image buffer.

  Args:
    image_buffer: The tensor representing the encoded image.
    image_format: The image format for the image in `image_buffer`.

  Returns:
    A tensor holding the decoded image. It is reshaped to `self._shape` when
    that is set; otherwise its static shape is (?, ?, self._channels).

  Raises:
    ValueError: when JPEG is the fallback decoder (3 or fewer channels) and
      `self._dtype` is not uint8.
  """
  def decode_png():
    return image_ops.decode_png(
        image_buffer, self._channels, dtype=self._dtype)

  def decode_raw():
    return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

  def decode_jpg():
    if self._dtype != dtypes.uint8:
      raise ValueError(
          'jpeg decoder can only be used to decode to tf.uint8 but %s was '
          'requested for a jpeg image.' % self._dtype)
    return image_ops.decode_jpeg(image_buffer, self._channels)

  def spelled_either_way(lower, upper):
    # True when `image_format` equals either spelling of the format name.
    return math_ops.logical_or(
        math_ops.equal(image_format, lower),
        math_ops.equal(image_format, upper))

  # For images with more than 3 channels (e.g. RGBA) JPEG is not a valid
  # decoder option, so PNG becomes the fallback and needs no predicate.
  jpeg_allowed = self._channels <= 3
  pred_fn_pairs = {}
  if jpeg_allowed:
    pred_fn_pairs[spelled_either_way('png', 'PNG')] = decode_png
  pred_fn_pairs[spelled_either_way('raw', 'RAW')] = decode_raw
  image = control_flow_ops.case(
      pred_fn_pairs,
      default=decode_jpg if jpeg_allowed else decode_png,
      exclusive=True)
  image.set_shape([None, None, self._channels])
  if self._shape is None:
    return image
  return array_ops.reshape(image, self._shape)
开发者ID:LUTAN,项目名称:tensorflow,代码行数:53,代码来源:tfexample_decoder.py
示例13: testCase_multiple_matches_exclusive
def testCase_multiple_matches_exclusive(self):
  """With exclusive=True, two true predicates make `case` fail at runtime."""
  selector = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  # The second and third predicates both test for 2, so feeding 2 makes more
  # than one predicate true at once.
  branches = [
      (math_ops.equal(selector, 1), lambda: constant_op.constant(2)),
      (math_ops.equal(selector, 2), lambda: constant_op.constant(4)),
      (math_ops.equal(selector, 2), lambda: constant_op.constant(6)),
  ]

  def fallback():
    return constant_op.constant(8)

  output = control_flow_ops.case(branches, fallback, exclusive=True)
  with self.test_session() as sess:
    for fed, expected in ((1, 2), (3, 8)):
      self.assertEqual(sess.run(output, feed_dict={selector: fed}), expected)
    with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
      sess.run(output, feed_dict={selector: 2})
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:12,代码来源:control_flow_ops_test.py
示例14: _testShape
def _testShape(self, fn_true, fn_false, expected_shape,
               strict=False):
  """Checks the static nested shape produced by `cond` and by `case`.

  Both ops share one boolean placeholder and the given branch functions. The
  `cond` result is checked before the `case` op is built.
  """
  predicate = array_ops.placeholder(dtypes.bool)
  builders = (
      lambda: control_flow_ops.cond(predicate, fn_true, fn_false,
                                    strict=strict),
      lambda: control_flow_ops.case([(predicate, fn_true)], fn_false,
                                    strict=strict),
  )
  for build in builders:
    self.assertEqual(_RawNestedShape(_GetNestedShape(build())),
                     _RawNestedShape(expected_shape))
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:12,代码来源:control_flow_ops_test.py
示例15: test_cov_update_thunks
def test_cov_update_thunks(self):
  """Ensures covariance update ops run once per global_step.

  Builds a `case` op whose i-th predicate is `global_step == i`, so each step
  runs only the matching covariance-update thunk. It then checks that each
  step changes exactly one more covariance matrix.
  """
  with self._graph.as_default(), self.test_session() as sess:
    fisher_estimator = estimator.FisherEstimatorRoundRobin(
        variables=[self.weights],
        layer_collection=self.layer_collection,
        damping=0.2,
        cov_ema_decay=0.0)

    # Construct an op that executes one covariance update per step.
    global_step = training_util.get_or_create_global_step()
    (cov_variable_thunks, cov_update_op_thunks, _,
     _) = fisher_estimator.create_ops_and_vars_thunks()
    for thunk in cov_variable_thunks:
      thunk()
    cov_matrices = [
        fisher_factor.get_cov()
        for fisher_factor in self.layer_collection.get_factors()
    ]
    cov_update_op = control_flow_ops.case(
        [(math_ops.equal(global_step, i), thunk)
         for i, thunk in enumerate(cov_update_op_thunks)])
    increment_global_step = global_step.assign_add(1)

    sess.run(variables.global_variables_initializer())
    initial_cov_values = sess.run(cov_matrices)

    # Ensure there's one update per covariance matrix.
    self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))

    # The test is vacuous with only one covariance matrix. Use a unittest
    # assertion rather than a bare `assert`, which `python -O` strips.
    self.assertGreater(len(cov_matrices), 1)

    for i in range(len(cov_matrices)):
      # Compare new and old covariance values.
      new_cov_values = sess.run(cov_matrices)
      num_cov_equal = sum(
          np.allclose(initial_cov_value, new_cov_value)
          for initial_cov_value, new_cov_value in zip(initial_cov_values,
                                                      new_cov_values))
      # After i steps exactly i covariance matrices have changed, i.e.
      # exactly one covariance matrix changes per step.
      self.assertEqual(num_cov_equal, len(cov_matrices) - i)

      # Run the single covariance update selected by the current global
      # step, then advance the step.
      sess.run(cov_update_op)
      sess.run(increment_global_step)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:49,代码来源:estimator_test.py
示例16: control_map_fn
def control_map_fn(x, y):
  """Returns `x // 2` when `y` is 2 or 3, and `x * 2` otherwise."""
  y_is_2_or_3 = math_ops.logical_or(math_ops.equal(y, 2),
                                    math_ops.equal(y, 3))
  return control_flow_ops.case(
      {y_is_2_or_3: lambda: x // 2},
      default=lambda: x * 2,
      exclusive=True)
开发者ID:bunbutter,项目名称:tensorflow,代码行数:15,代码来源:map_test.py
示例17: testCase_withoutDefault
def testCase_withoutDefault(self):
  """Without a default, `case` raises when no predicate holds."""
  selector = array_ops.placeholder(dtype=dtypes.int32, shape=[])
  # The predicates are created in this order, so they get the op names that
  # the error-message pattern below expects (Equal, Equal_1, Equal_2).
  branches = [
      (math_ops.equal(selector, 1), lambda: constant_op.constant(2)),
      (math_ops.equal(selector, 2), lambda: constant_op.constant(4)),
      (math_ops.equal(selector, 3), lambda: constant_op.constant(6)),
  ]
  output = control_flow_ops.case(branches, exclusive=True)
  with self.test_session() as sess:
    for fed, expected in ((1, 2), (2, 4), (3, 6)):
      self.assertEqual(sess.run(output, feed_dict={selector: fed}), expected)
    no_match_pattern = (
        r"\[None of the conditions evaluated as True. "
        r"Conditions: \(Equal:0, Equal_1:0, Equal_2:0\), Values:\] "
        r"\[0 0 0\]")
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 no_match_pattern):
      sess.run(output, feed_dict={selector: 4})
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:16,代码来源:control_flow_ops_test.py
示例18: conditional_decoding
def conditional_decoding(keys_to_tensors):
  """Decodes 'image/encoded' as PNG or JPEG, depending on 'image/format'.

  Only the exact string 'png' selects the PNG decoder; every other format
  value is decoded as JPEG. Both decoders are asked for 3 channels. The
  result is reshaped to `image_shape`, a name taken from the enclosing scope.
  """
  encoded = keys_to_tensors['image/encoded']
  fmt = keys_to_tensors['image/format']
  decoded = control_flow_ops.case(
      {math_ops.equal(fmt, 'png'): lambda: image_ops.decode_png(encoded, 3)},
      default=lambda: image_ops.decode_jpeg(encoded, 3),
      exclusive=True)
  return array_ops.reshape(decoded, image_shape)
开发者ID:AlexMikhalev,项目名称:polyaxon,代码行数:17,代码来源:test_data_decoder.py
示例19: decayed_lr
def decayed_lr(x, boundaries, values, name):
  """Helper to recompute learning rate; most helpful in eager-mode.

  Returns `values[0]` when `x <= boundaries[0]`, `values[i]` when
  `boundaries[i-1] < x <= boundaries[i]`, and `values[-1]` when
  `x > boundaries[-1]`.

  Raises:
    ValueError: if a boundary's dtype differs from `x`'s (other than int32
      boundaries with an int64 `x`, which are cast), if the elements of
      `values` do not all share one dtype, or if `values` does not have
      exactly one more element than `boundaries`.
  """
  with ops.name_scope(name, "PiecewiseConstant",
                      [x, boundaries, values, name]) as name:
    boundaries = ops.convert_n_to_tensor(boundaries)
    values = ops.convert_n_to_tensor(values)
    x_recomp = ops.convert_to_tensor(x)
    # Avoid explicit conversion to x's dtype. This could result in faulty
    # comparisons, for example if floats are converted to integers.
    for i, b in enumerate(boundaries):
      if b.dtype.base_dtype != x_recomp.dtype.base_dtype:
        # We can promote int32 boundaries to int64 without loss of precision.
        # This covers the most common case where the user passes in boundaries
        # as an array of Python integers.
        if (b.dtype.base_dtype == dtypes.int32 and
            x_recomp.dtype.base_dtype == dtypes.int64):
          b = math_ops.cast(b, x_recomp.dtype.base_dtype)
          boundaries[i] = b
        else:
          raise ValueError(
              "Boundaries (%s) must have the same dtype as x (%s)." %
              (b.dtype.base_dtype, x_recomp.dtype.base_dtype))
    # TODO(rdipietro): Ensure that boundaries' elements strictly increases.
    for v in values[1:]:
      if v.dtype.base_dtype != values[0].dtype.base_dtype:
        raise ValueError(
            "Values must have elements all with the same dtype (%s vs %s)." %
            (values[0].dtype.base_dtype, v.dtype.base_dtype))
    # Without this check a mismatched `values` list is accepted silently: the
    # zip below truncates, so some intervals get no predicate and those
    # inputs fall through to the default.
    if len(values) != len(boundaries) + 1:
      raise ValueError(
          "Values must have exactly one more element than boundaries "
          "(got %d boundaries and %d values)." % (len(boundaries), len(values)))
    pred_fn_pairs = []
    pred_fn_pairs.append((x_recomp <= boundaries[0], lambda: values[0]))
    pred_fn_pairs.append((x_recomp > boundaries[-1], lambda: values[-1]))
    for low, high, v in zip(boundaries[:-1], boundaries[1:], values[1:-1]):
      # Need to bind v here; can do this with lambda v=v: ...
      pred = (x_recomp > low) & (x_recomp <= high)
      pred_fn_pairs.append((pred, lambda v=v: v))

    # The default isn't needed here because our conditions are mutually
    # exclusive and exhaustive, but tf.case requires it.
    default = lambda: values[0]
    return control_flow_ops.case(pred_fn_pairs, default, exclusive=True)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:40,代码来源:learning_rate_decay_v2.py
示例20: _decode
def _decode(self, image_buffer, image_format):
  """Decodes the image buffer.

  Args:
    image_buffer: The tensor representing the encoded image.
    image_format: The image format for the image in `image_buffer`. Only the
      exact string 'png' selects the PNG decoder; any other value is decoded
      as JPEG.

  Returns:
    A decoded image, reshaped to `self._shape` when that is set.
  """
  is_png = math_ops.equal(image_format, 'png')
  image = control_flow_ops.case(
      {is_png: lambda: image_ops.decode_png(image_buffer, self._channels)},
      default=lambda: image_ops.decode_jpeg(image_buffer, self._channels),
      exclusive=True)
  if self._shape is None:
    return image
  return array_ops.reshape(image, self._shape)
开发者ID:285219011,项目名称:hello-world,代码行数:22,代码来源:tfexample_decoder.py
注:本文中的tensorflow.python.ops.control_flow_ops.case函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论