本文整理汇总了 Python 中 tensorflow.python.ops.check_ops.assert_same_float_dtype 函数的典型用法代码示例。如果您正苦于以下问题：Python assert_same_float_dtype 函数的具体用法？Python assert_same_float_dtype 怎么用？Python assert_same_float_dtype 使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了assert_same_float_dtype函数的15个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: __init__
def __init__(self,
             loc=0.,
             scale=1.,
             validate_args=False,
             name="gumbel"):
  """Builds the `Gumbel` bijector from a location and a scale.

  The transformation is `Y = g(X) = exp(-exp(-(X - loc) / scale))`.

  Args:
    loc: Float-like `Tensor` with the same dtype as `scale` and
      broadcastable with it. This is `loc` in the formula above.
    scale: Positive float-like `Tensor` with the same dtype as `loc` and
      broadcastable with it. This is `scale` in the formula above.
    validate_args: Python `bool`. When `True`, a positivity assertion is
      attached to `scale`.
    name: Python `str` name given to ops managed by this object.
  """
  self._graph_parents = []
  self._name = name
  self._validate_args = validate_args

  with self._name_scope("init", values=[loc, scale]):
    loc_t = ops.convert_to_tensor(loc, name="loc")
    scale_t = ops.convert_to_tensor(scale, name="scale")
    # Both parameters must agree on one floating-point dtype.
    check_ops.assert_same_float_dtype([loc_t, scale_t])
    if validate_args:
      scale_is_positive = check_ops.assert_positive(
          scale_t, message="Argument scale was not positive")
      # Gate every later use of `scale` on the positivity assertion.
      scale_t = control_flow_ops.with_dependencies(
          [scale_is_positive], scale_t)
    self._loc = loc_t
    self._scale = scale_t

  super(Gumbel, self).__init__(
      validate_args=validate_args,
      forward_min_event_ndims=0,
      name=name)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:35,代码来源:gumbel.py
示例2: __init__
def __init__(self,
             skewness=0.,
             tailweight=1.,
             event_ndims=0,
             validate_args=False,
             name="sinh_arcsinh"):
  """Builds the `SinhArcsinh` bijector.

  Args:
    skewness: Skewness parameter. Float-type `Tensor`.
    tailweight: Tailweight parameter. Positive `Tensor` with the same
      `dtype` as `skewness` and a broadcastable `shape`.
    event_ndims: Python scalar giving the number of dimensions associated
      with a particular draw from the distribution.
    validate_args: Python `bool`. When `True`, a positivity assertion is
      attached to `tailweight`.
    name: Python `str` name given to ops managed by this object.
  """
  self._graph_parents = []
  self._name = name
  self._validate_args = validate_args

  with self._name_scope("init", values=[skewness, tailweight]):
    skewness_t = ops.convert_to_tensor(skewness, name="skewness")
    tailweight_t = ops.convert_to_tensor(tailweight, name="tailweight")
    # Both parameters must agree on one floating-point dtype.
    check_ops.assert_same_float_dtype([skewness_t, tailweight_t])
    if validate_args:
      tailweight_is_positive = check_ops.assert_positive(
          tailweight_t,
          message="Argument tailweight was not positive")
      # Gate every later use of `tailweight` on the positivity assertion.
      tailweight_t = control_flow_ops.with_dependencies(
          [tailweight_is_positive], tailweight_t)
    self._skewness = skewness_t
    self._tailweight = tailweight_t

  super(SinhArcsinh, self).__init__(
      event_ndims=event_ndims, validate_args=validate_args, name=name)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:34,代码来源:sinh_arcsinh_impl.py
示例3: _maybe_assert_valid_sample
def _maybe_assert_valid_sample(self, x):
  """Checks the dtype of `x` and optionally asserts that it is positive.

  Args:
    x: Sample to validate; its dtype is checked against `self.dtype`.

  Returns:
    `x` unchanged when `self.validate_args` is falsy; otherwise `x` with a
    control dependency on a positivity assertion.
  """
  check_ops.assert_same_float_dtype(tensors=[x], dtype=self.dtype)
  if self.validate_args:
    positivity_check = check_ops.assert_positive(x)
    return control_flow_ops.with_dependencies([positivity_check], x)
  return x
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:gamma.py
示例4: __init__
def __init__(self,
             df,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="StudentT"):
  """Construct Student's t distributions.

  Each distribution has degrees of freedom `df`, mean `loc` and scale
  `scale`. The three parameters must be shaped so that they broadcast
  together (e.g. `df + loc + scale` is a valid operation).

  Args:
    df: Floating-point `Tensor`. The degrees of freedom of the
      distribution(s). `df` must contain only positive values.
    loc: Floating-point `Tensor`. The mean(s) of the distribution(s).
    scale: Floating-point `Tensor`. The scaling factor(s) for the
      distribution(s). Note that `scale` is not technically the standard
      deviation of this distribution but has semantics more similar to
      standard deviation than variance.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if `df`, `loc` and `scale` do not share a single
      floating-point dtype (reported by
      `check_ops.assert_same_float_dtype`).
  """
  # Snapshot of the constructor arguments, taken before any local is added.
  parameters = dict(locals())
  with ops.name_scope(name, values=[df, loc, scale]) as name:
    # The positivity check on `df` is only built when validation is on.
    df_checks = [check_ops.assert_positive(df)] if validate_args else []
    with ops.control_dependencies(df_checks):
      self._df = array_ops.identity(df, name="df")
      self._loc = array_ops.identity(loc, name="loc")
      self._scale = array_ops.identity(scale, name="scale")
      params = (self._df, self._loc, self._scale)
      check_ops.assert_same_float_dtype(params)

  super(StudentT, self).__init__(
      dtype=self._scale.dtype,
      reparameterization_type=distribution.FULLY_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=[self._df, self._loc, self._scale],
      name=name)
开发者ID:daiwk,项目名称:tensorflow,代码行数:53,代码来源:student_t.py
示例5: __init__
def __init__(self,
             concentration,
             rate,
             validate_args=False,
             allow_nan_stats=True,
             name="InverseGamma"):
  """Construct InverseGamma with `concentration` and `rate` parameters.

  `concentration` and `rate` must be shaped so that they broadcast together
  (e.g. `concentration + rate` is a valid operation).

  Args:
    concentration: Floating point tensor, the concentration params of the
      distribution(s). Must contain only positive values.
    rate: Floating point tensor, the inverse scale params of the
      distribution(s). Must contain only positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if `concentration` and `rate` do not share a single
      floating-point dtype (reported by
      `check_ops.assert_same_float_dtype`).
  """
  # Snapshot of the constructor arguments, taken before any local is added.
  parameters = dict(locals())
  with ops.name_scope(name, values=[concentration, rate]) as name:
    # Positivity checks are only built when validation is requested.
    positivity_checks = []
    if validate_args:
      positivity_checks.append(check_ops.assert_positive(concentration))
      positivity_checks.append(check_ops.assert_positive(rate))
    with ops.control_dependencies(positivity_checks):
      self._concentration = array_ops.identity(
          concentration, name="concentration")
      self._rate = array_ops.identity(rate, name="rate")
      params = [self._concentration, self._rate]
      check_ops.assert_same_float_dtype(params)

  super(InverseGamma, self).__init__(
      dtype=self._concentration.dtype,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      reparameterization_type=distribution.NOT_REPARAMETERIZED,
      parameters=parameters,
      graph_parents=[self._concentration, self._rate],
      name=name)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:50,代码来源:inverse_gamma.py
示例6: test_assert_same_float_dtype
def test_assert_same_float_dtype(self):
  """Exercises `assert_same_float_dtype` on empty, dense, sparse and mixed inputs."""
  check = check_ops.assert_same_float_dtype
  f32 = dtypes.float32

  # No usable tensors: the expected result is float32, with or without an
  # explicit float32 dtype.
  for tensors, dtype in ((None, None),
                         ([], None),
                         ([], f32),
                         (None, f32),
                         ([None, None], None),
                         ([None, None], f32)):
    self.assertIs(f32, check(tensors, dtype))

  # A dense float32 tensor.
  const_float = constant_op.constant(3.0, dtype=f32)
  self.assertIs(f32, check([const_float], f32))
  with self.assertRaises(ValueError):
    check([const_float], dtypes.int32)

  # A sparse float32 tensor.
  sparse_float = sparse_tensor.SparseTensor(
      constant_op.constant([[111], [232]], dtypes.int64),
      constant_op.constant([23.4, -43.2], f32),
      constant_op.constant([500], dtypes.int64))
  self.assertIs(f32, check([sparse_float], f32))
  with self.assertRaises(ValueError):
    check([sparse_float], dtypes.int32)

  # Dense and sparse float32 tensors together.
  with self.assertRaises(ValueError):
    check([const_float, None, sparse_float], dtypes.float64)
  self.assertIs(f32, check([const_float, sparse_float]))
  self.assertIs(f32, check([const_float, sparse_float], f32))

  # Any integer tensor in the list is rejected, whatever dtype is requested.
  const_int = constant_op.constant(3, dtype=dtypes.int32)
  with self.assertRaises(ValueError):
    check([sparse_float, const_int])
  with self.assertRaises(ValueError):
    check([sparse_float, const_int], dtypes.int32)
  with self.assertRaises(ValueError):
    check([sparse_float, const_int], f32)
  with self.assertRaises(ValueError):
    check([const_int])
开发者ID:1000sprites,项目名称:tensorflow,代码行数:49,代码来源:check_ops_test.py
示例7: __init__
def __init__(self,
             concentration1=None,
             concentration0=None,
             validate_args=False,
             allow_nan_stats=True,
             name="Beta"):
  """Initialize a batch of Beta distributions.

  Args:
    concentration1: Positive floating-point `Tensor` indicating mean
      number of successes; aka "alpha". Implies `self.dtype` and
      `self.batch_shape`, i.e.,
      `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`.
    concentration0: Positive floating-point `Tensor` indicating mean
      number of failures; aka "beta". Otherwise has same semantics as
      `concentration1`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.
  """
  # Snapshot of the constructor arguments, taken before any local is added.
  parameters = dict(locals())
  with ops.name_scope(name, values=[concentration1, concentration0]) as name:
    alpha = ops.convert_to_tensor(concentration1, name="concentration1")
    self._concentration1 = self._maybe_assert_valid_concentration(
        alpha, validate_args)
    beta = ops.convert_to_tensor(concentration0, name="concentration0")
    self._concentration0 = self._maybe_assert_valid_concentration(
        beta, validate_args)
    check_ops.assert_same_float_dtype(
        [self._concentration1, self._concentration0])
    # The sum is computed once here; its dtype is handed to the base class.
    self._total_concentration = self._concentration1 + self._concentration0

  super(Beta, self).__init__(
      dtype=self._total_concentration.dtype,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      reparameterization_type=distribution.FULLY_REPARAMETERIZED,
      parameters=parameters,
      graph_parents=[
          self._concentration1,
          self._concentration0,
          self._total_concentration,
      ],
      name=name)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:47,代码来源:beta.py
示例8: __init__
def __init__(self,
             scale=1.,
             concentration=1.,
             event_ndims=0,
             validate_args=False,
             name="weibull"):
  """Builds the `Weibull` bijector.

  The transformation is `Y = g(X) = 1 - exp((-x / l) ** k)`.

  Args:
    scale: Positive float-type `Tensor` with the same dtype as
      `concentration` and broadcastable with it. This is `l` in the
      formula above.
    concentration: Positive float-type `Tensor` with the same dtype as
      `scale` and broadcastable with it. This is `k` in the formula above.
    event_ndims: Python scalar giving the number of dimensions associated
      with a particular draw from the distribution.
    validate_args: Python `bool`. When `True`, positivity assertions are
      attached to `scale` and `concentration`.
    name: Python `str` name given to ops managed by this object.
  """
  self._graph_parents = []
  self._name = name
  self._validate_args = validate_args

  with self._name_scope("init", values=[scale, concentration]):
    scale_t = ops.convert_to_tensor(scale, name="scale")
    concentration_t = ops.convert_to_tensor(
        concentration, name="concentration")
    # Both parameters must agree on one floating-point dtype.
    check_ops.assert_same_float_dtype([scale_t, concentration_t])
    if validate_args:
      # Gate every later use of each parameter on its positivity assertion.
      scale_is_positive = check_ops.assert_positive(
          scale_t,
          message="Argument scale was not positive")
      scale_t = control_flow_ops.with_dependencies(
          [scale_is_positive], scale_t)
      concentration_is_positive = check_ops.assert_positive(
          concentration_t,
          message="Argument concentration was not positive")
      concentration_t = control_flow_ops.with_dependencies(
          [concentration_is_positive], concentration_t)
    self._scale = scale_t
    self._concentration = concentration_t

  super(Weibull, self).__init__(
      event_ndims=event_ndims,
      validate_args=validate_args,
      name=name)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:45,代码来源:weibull.py
示例9: __init__
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="Laplace"):
  """Construct Laplace distribution with parameters `loc` and `scale`.

  The parameters `loc` and `scale` must be shaped in a way that supports
  broadcasting (e.g., `loc / scale` is a valid operation).

  Args:
    loc: Floating point tensor which characterizes the location (center)
      of the distribution.
    scale: Positive floating point tensor which characterizes the spread of
      the distribution.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`,
      statistics (e.g., mean, mode, variance) use the value "`NaN`" to
      indicate the result is undefined. When `False`, an exception is raised
      if one or more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    ValueError: if `loc` and `scale` do not share a single floating-point
      dtype (reported by `check_ops.assert_same_float_dtype`).
  """
  # Take a copy: a bare `locals()` hands back the frame's own dict, which
  # can be refreshed later (another `locals()` call, a tracer) and would
  # then pick up names that are not constructor arguments.
  parameters = dict(locals())
  with ops.name_scope(name, values=[loc, scale]) as name:
    # The positivity check on `scale` is only built when validation is on.
    with ops.control_dependencies([check_ops.assert_positive(scale)] if
                                  validate_args else []):
      self._loc = array_ops.identity(loc, name="loc")
      self._scale = array_ops.identity(scale, name="scale")
      check_ops.assert_same_float_dtype([self._loc, self._scale])
  super(Laplace, self).__init__(
      dtype=self._loc.dtype,
      reparameterization_type=distribution.FULLY_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=[self._loc, self._scale],
      name=name)
开发者ID:Jackiefan,项目名称:tensorflow,代码行数:44,代码来源:laplace.py
示例10: __init__
def __init__(self,
             low=0.,
             high=1.,
             validate_args=False,
             allow_nan_stats=True,
             name="Uniform"):
  """Initialize a batch of Uniform distributions.

  Args:
    low: Floating point tensor, lower boundary of the output interval. Must
      have `low < high`.
    high: Floating point tensor, upper boundary of the output interval. Must
      have `low < high`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    InvalidArgumentError: if `low >= high` and `validate_args=True` (the
      `assert_less` check is only attached when validation is enabled).
    ValueError: if `low` and `high` do not share a single floating-point
      dtype (reported by `check_ops.assert_same_float_dtype`).
  """
  # Take a copy: a bare `locals()` hands back the frame's own dict, which
  # can be refreshed later (another `locals()` call, a tracer) and would
  # then pick up names that are not constructor arguments.
  parameters = dict(locals())
  with ops.name_scope(name, values=[low, high]):
    # The ordering check is only built when validation is requested.
    with ops.control_dependencies([
        check_ops.assert_less(
            low, high, message="uniform not defined when low >= high.")
    ] if validate_args else []):
      self._low = array_ops.identity(low, name="low")
      self._high = array_ops.identity(high, name="high")
      check_ops.assert_same_float_dtype([self._low, self._high])
  super(Uniform, self).__init__(
      dtype=self._low.dtype,
      reparameterization_type=distribution.FULLY_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=[self._low,
                     self._high],
      name=name)
开发者ID:LUTAN,项目名称:tensorflow,代码行数:44,代码来源:uniform.py
示例11: __init__
def __init__(self,
             loc,
             scale,
             validate_args=False,
             allow_nan_stats=True,
             name="Logistic"):
  """Construct Logistic distributions with mean and scale `loc` and `scale`.

  `loc` and `scale` must be shaped so that they broadcast together
  (e.g. `loc + scale` is a valid operation).

  Args:
    loc: Floating point tensor, the means of the distribution(s).
    scale: Floating point tensor, the scales of the distribution(s). Must
      contain only positive values.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: The name to give Ops created by the initializer.

  Raises:
    ValueError: if `loc` and `scale` do not share a single floating-point
      dtype (reported by `check_ops.assert_same_float_dtype`).
  """
  # Must stay the first statement and be called directly from this frame.
  parameters = distribution_util.parent_frame_arguments()
  with ops.name_scope(name, values=[loc, scale]) as name:
    # The positivity check on `scale` is only built when validation is on.
    scale_checks = (
        [check_ops.assert_positive(scale)] if validate_args else [])
    with ops.control_dependencies(scale_checks):
      self._loc = array_ops.identity(loc, name="loc")
      self._scale = array_ops.identity(scale, name="scale")
      params = [self._loc, self._scale]
      check_ops.assert_same_float_dtype(params)

  super(Logistic, self).__init__(
      dtype=self._scale.dtype,
      reparameterization_type=distribution.FULLY_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      parameters=parameters,
      graph_parents=[self._loc, self._scale],
      name=name)
开发者ID:didukhle,项目名称:tensorflow,代码行数:43,代码来源:logistic.py
示例12: __init__
def __init__(self,
             skewness=None,
             tailweight=None,
             validate_args=False,
             name="SinhArcsinh"):
  """Builds the `SinhArcsinh` bijector.

  Args:
    skewness: Skewness parameter. Float-type `Tensor`. `None` is replaced
      by `0.`.
    tailweight: Tailweight parameter. Positive `Tensor` with a broadcastable
      `shape`; it is converted using the dtype of `skewness`. `None` is
      replaced by `1.`.
    validate_args: Python `bool`. When `True`, a positivity assertion is
      attached to `tailweight`.
    name: Python `str` name given to ops managed by this object.
  """
  self._graph_parents = []
  self._name = name
  self._validate_args = validate_args

  with self._name_scope("init", values=[skewness, tailweight]):
    # Defaults are filled in inside the scope, after `values` was built
    # from the raw arguments.
    if tailweight is None:
      tailweight = 1.
    if skewness is None:
      skewness = 0.
    skewness_t = ops.convert_to_tensor(skewness, name="skewness")
    # `tailweight` is converted with the dtype that `skewness` ended up with.
    tailweight_t = ops.convert_to_tensor(
        tailweight, name="tailweight", dtype=skewness_t.dtype)
    self._skewness = skewness_t
    self._tailweight = tailweight_t
    check_ops.assert_same_float_dtype([skewness_t, tailweight_t])
    if validate_args:
      tailweight_is_positive = check_ops.assert_positive(
          tailweight_t,
          message="Argument tailweight was not positive")
      self._tailweight = control_flow_ops.with_dependencies(
          [tailweight_is_positive], tailweight_t)

  super(SinhArcsinh, self).__init__(
      forward_min_event_ndims=0,
      validate_args=validate_args,
      name=name)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:37,代码来源:sinh_arcsinh.py
示例13: __init__
def __init__(self,
             distribution,
             low=None,
             high=None,
             validate_args=False,
             name="QuantizedDistribution"):
  """Construct a Quantized Distribution representing `Y = ceiling(X)`.

  Some properties are inherited from the distribution defining `X`. Example:
  `allow_nan_stats` is determined for this `QuantizedDistribution` by reading
  the `distribution`.

  Args:
    distribution: The base distribution class to transform. Typically an
      instance of `Distribution`.
    low: `Tensor` with same `dtype` as this distribution and shape
      able to be added to samples. Should be a whole number. Default `None`.
      If provided, base distribution's `prob` should be defined at
      `low`.
    high: `Tensor` with same `dtype` as this distribution and shape
      able to be added to samples. Should be a whole number. Default `None`.
      If provided, base distribution's `prob` should be defined at
      `high - 1`.
      `high` must be strictly greater than `low`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: If `distribution` is not a subclass of
      `Distribution` or continuous.
    NotImplementedError: If the base distribution does not implement `cdf`.
    ValueError: If `distribution`, `low` and `high` do not share a single
      floating-point dtype (reported by
      `check_ops.assert_same_float_dtype`).

  NOTE(review): `TypeError` and `NotImplementedError` are not raised directly
  in this constructor body; presumably they originate in helpers or later
  method calls -- confirm.
  """
  # Take a copy: a bare `locals()` hands back the frame's own dict, which
  # can be refreshed later and would then pick up the locals defined below.
  parameters = dict(locals())
  values = (
      list(distribution.parameters.values()) +
      [low, high])
  with ops.name_scope(name, values=values):
    self._dist = distribution

    if low is not None:
      low = ops.convert_to_tensor(low, name="low")
    if high is not None:
      high = ops.convert_to_tensor(high, name="high")
    check_ops.assert_same_float_dtype(
        tensors=[self.distribution, low, high])

    # We let QuantizedDistribution access _graph_parents since this class is
    # more like a baseclass. Work on a copy: extending the wrapped
    # distribution's own list in place would leak `low`/`high` into it.
    graph_parents = list(self._dist._graph_parents)  # pylint: disable=protected-access

    checks = []
    if low is not None and high is not None:
      message = "low must be strictly less than high."
      checks.append(
          check_ops.assert_less(
              low, high, message=message))
    self._validate_args = validate_args  # self._check_integer uses this.

    # The ordering check only gates the bounds when validation is requested.
    with ops.control_dependencies(checks if validate_args else []):
      if low is not None:
        self._low = self._check_integer(low)
        graph_parents.append(self._low)
      else:
        self._low = None
      if high is not None:
        self._high = self._check_integer(high)
        graph_parents.append(self._high)
      else:
        self._high = None

  super(QuantizedDistribution, self).__init__(
      dtype=self._dist.dtype,
      reparameterization_type=distributions.NOT_REPARAMETERIZED,
      validate_args=validate_args,
      allow_nan_stats=self._dist.allow_nan_stats,
      parameters=parameters,
      graph_parents=graph_parents,
      name=name)
开发者ID:finardi,项目名称:tensorflow,代码行数:80,代码来源:quantized_distribution.py
示例14: __init__
#.........这里部分代码省略.........
Raises:
ValueError: If `is_X` flags are set in an inconsistent way.
"""
# TODO(langmore) support complex types.
# Complex types are not allowed due to tf.cholesky() requiring float.
# If complex dtypes are allowed, we update the following
# 1. is_diag_update_positive should still imply that `diag > 0`, but we need
# to remind the user that this implies diag is real. This is needed
# because if diag has non-zero imaginary part, it will not be
# self-adjoint positive definite.
dtype = base_operator.dtype
allowed_dtypes = [
dtypes.float16,
dtypes.float32,
dtypes.float64,
]
if dtype not in allowed_dtypes:
raise TypeError(
"Argument matrix must have dtype in %s. Found: %s"
% (allowed_dtypes, dtype))
if diag_update is None:
if is_diag_update_positive is False:
raise ValueError(
"Default diagonal is the identity, which is positive. However, "
"user set 'is_diag_update_positive' to False.")
is_diag_update_positive = True
# In this case, we can use a Cholesky decomposition to help us solve/det.
self._use_cholesky = (
base_operator.is_positive_definite and base_operator.is_self_adjoint
and is_diag_update_positive
and v is None)
# Possibly auto-set some characteristic flags from None to True.
# If the Flags were set (by the user) incorrectly to False, then raise.
if base_operator.is_self_adjoint and v is None and not dtype.is_complex:
if is_self_adjoint is False:
raise ValueError(
"A = L + UDU^H, with L self-adjoint and D real diagonal. Since"
" UDU^H is self-adjoint, this must be a self-adjoint operator.")
is_self_adjoint = True
# The condition for using a cholesky is sufficient for SPD, and
# we no weaker choice of these hints leads to SPD. Therefore,
# the following line reads "if hints indicate SPD..."
if self._use_cholesky:
if (
is_positive_definite is False
or is_self_adjoint is False
or is_non_singular is False):
raise ValueError(
"Arguments imply this is self-adjoint positive-definite operator.")
is_positive_definite = True
is_self_adjoint = True
values = base_operator.graph_parents + [u, diag_update, v]
with ops.name_scope(name, values=values):
# Create U and V.
self._u = ops.convert_to_tensor(u, name="u")
if v is None:
self._v = self._u
else:
self._v = ops.convert_to_tensor(v, name="v")
if diag_update is None:
self._diag_update = None
else:
self._diag_update = ops.convert_to_tensor(
diag_update, name="diag_update")
# Create base_operator L.
self._base_operator = base_operator
graph_parents = base_operator.graph_parents + [
self.u, self._diag_update, self.v]
graph_parents = [p for p in graph_parents if p is not None]
super(LinearOperatorLowRankUpdate, self).__init__(
dtype=self._base_operator.dtype,
graph_parents=graph_parents,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
# Create the diagonal operator D.
self._set_diag_operators(diag_update, is_diag_update_positive)
self._is_diag_update_positive = is_diag_update_positive
check_ops.assert_same_float_dtype((base_operator, self.u, self.v,
self._diag_update))
self._check_shapes()
# Pre-compute the so-called "capacitance" matrix
# C := D^{-1} + V^H L^{-1} U
self._capacitance = self._make_capacitance()
if self._use_cholesky:
self._chol_capacitance = linalg_ops.cholesky(self._capacitance)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:101,代码来源:linear_operator_low_rank_update.py
示例15: __init__
def __init__(self,
             df,
             scale_operator_pd,
             cholesky_input_output_matrices=False,
             validate_args=False,
             allow_nan_stats=True,
             name=None):
  """Construct Wishart distributions.

  Args:
    df: `float` or `double` tensor, the degrees of freedom of the
      distribution(s). `df` must be greater than or equal to `k`.
    scale_operator_pd: `float` or `double` instance of `OperatorPDBase`.
    cholesky_input_output_matrices: Python `bool`. Any function whose
      input or output is a matrix assumes the input is Cholesky and returns a
      Cholesky factored matrix. Example `log_prob` input takes a Cholesky and
      `sample_n` returns a Cholesky when
      `cholesky_input_output_matrices=True`.
    validate_args: Python `bool`, default `False`. When `True` distribution
      parameters are checked for validity despite possibly degrading runtime
      performance. When `False` invalid inputs may silently render incorrect
      outputs.
    allow_nan_stats: Python `bool`, default `True`. When `True`, statistics
      (e.g., mean, mode, variance) use the value "`NaN`" to indicate the
      result is undefined. When `False`, an exception is raised if one or
      more of the statistic's batch members are undefined.
    name: Python `str` name prefixed to Ops created by this class.

  Raises:
    TypeError: if scale is not floating-type
    TypeError: if scale.dtype != df.dtype
    ValueError: if df < k, where scale operator event shape is
      `(k, k)`
  """
  # Take a copy: a bare `locals()` hands back the frame's own dict, which
  # can be refreshed later and would then pick up the locals defined below.
  parameters = dict(locals())
  self._cholesky_input_output_matrices = cholesky_input_output_matrices
  with ops.name_scope(name):
    with ops.name_scope("init", values=[df, scale_operator_pd]):
      if not scale_operator_pd.dtype.is_floating:
        raise TypeError(
            "scale_operator_pd.dtype=%s is not a floating-point type" %
            scale_operator_pd.dtype)
      self._scale_operator_pd = scale_operator_pd
      # `df` is converted using the scale operator's dtype.
      self._df = ops.convert_to_tensor(
          df,
          dtype=scale_operator_pd.dtype,
          name="df")
      check_ops.assert_same_float_dtype(
          (self._df, self._scale_operator_pd))
      # Use the static last dimension when it is known; otherwise ask the
      # operator for it at graph-run time.
      if (self._scale_operator_pd.get_shape().ndims is None or
          self._scale_operator_pd.get_shape()[-1].value is None):
        self._dimension = math_ops.cast(
            self._scale_operator_pd.vector_space_dimension(),
            dtype=self._scale_operator_pd.dtype, name="dimension")
      else:
        self._dimension = ops.convert_to_tensor(
            self._scale_operator_pd.get_shape()[-1].value,
            dtype=self._scale_operator_pd.dtype, name="dimension")
      df_val = tensor_util.constant_value(self._df)
      dim_val = tensor_util.constant_value(self._dimension)
      if df_val is not None and dim_val is not None:
        # Both values are known statically, so validate eagerly.
        df_val = np.asarray(df_val)
        if not df_val.shape:
          # Wrap a scalar so that `any(...)` below has something to iterate.
          df_val = [df_val]
        if any(df_val < dim_val):
          raise ValueError(
              "Degrees of freedom (df = %s) cannot be less than "
              "dimension of scale matrix (scale.dimension = %s)"
              % (df_val, dim_val))
      elif validate_args:
        # Values are not known statically: defer the check to run time.
        # The format arguments follow the placeholders: `df` first, then
        # the dimension.
        assertions = check_ops.assert_less_equal(
            self._dimension, self._df,
            message=("Degrees of freedom (df = %s) cannot be "
                     "less than dimension of scale matrix "
                     "(scale.dimension = %s)" %
                     (self._df, self._dimension)))
        self._df = control_flow_ops.with_dependencies(
            [assertions], self._df)
  super(_WishartOperatorPD, self).__init__(
      dtype=self._scale_operator_pd.dtype,
      validate_args=validate_args,
      allow_nan_stats=allow_nan_stats,
      reparameterization_type=distribution.FULLY_REPARAMETERIZED,
      parameters=parameters,
      graph_parents=([self._df, self._dimension] +
                     self._scale_operator_pd.inputs),
      name=name)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:87,代码来源:wishart.py
注：本文中的 tensorflow.python.ops.check_ops.assert_same_float_dtype 函数示例由纯净天空整理自 Github/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。
请发表评论