• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python function_utils.fn_args函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.util.function_utils.fn_args函数的典型用法代码示例。如果您正苦于以下问题：Python fn_args函数的具体用法？Python fn_args怎么用？Python fn_args使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了fn_args函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: add

 def add(self, layer_func):
   """Appends a layer or plain callable to the sequence of layer functions.

   Args:
     layer_func: A `base.Layer` instance (whose `call` signature is inspected
       and which is also tracked) or any other callable (whose own signature is
       inspected).

   Raises:
     TypeError: if `layer_func` is neither a `base.Layer` nor callable.
   """
   if isinstance(layer_func, base.Layer):
     # Layers are inspected through their `call` method, then tracked.
     arg_names = function_utils.fn_args(layer_func.call)
     self.track_layer(layer_func)
   elif not callable(layer_func):
     raise TypeError(
         "Sequential.add() takes only tf.layers.Layer objects or callables; "
         "not '%s' of type '%s'." % (layer_func, type(layer_func)))
   else:
     arg_names = function_utils.fn_args(layer_func)
   # Store whether the function declares a `training` parameter alongside it
   # (presumably so `training` can be forwarded when the sequence runs).
   accepts_training = "training" in arg_names
   self._layers_funcs.append((accepts_training, layer_func))
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:11,代码来源:network.py


示例2: eval_step

    def eval_step():
      """A single step of evaluation.

      Calls the model function in EVAL mode, records the spec's scaffold
      function and eval metric function (None when absent) via the enclosing
      `captured_*` objects, and returns the loss followed by the eval metric
      tensors.

      Returns:
        A tuple `(loss, *eval_metric_fn_tensors)`. When the tensors were given
        as a dict, they are ordered by the metric function's argument names.
      """
      estimator_spec = self._call_model_fn(features, labels,
                                           model_fn_lib.ModeKeys.EVAL, params)

      # Specs without a `scaffold_fn` attribute capture None instead.
      try:
        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)
      except AttributeError:
        captured_scaffold_fn.capture(None)

      eval_metric_fn = None
      eval_metric_fn_tensors = []
      try:
        if estimator_spec.eval_metrics:
          (eval_metric_fn, eval_metric_fn_tensors) = estimator_spec.eval_metrics
      except AttributeError:
        pass

      # If a dictionary is provided, we need to convert it into a list sorted
      # according to order of eval_metric_fn positional arguments.
      if isinstance(eval_metric_fn_tensors, dict):
        eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)
        eval_metric_fn_tensors = [
            eval_metric_fn_tensors[i] for i in eval_metric_fn_args
        ]

      captured_eval_metric_fn.capture(eval_metric_fn)

      # The tensors may arrive as a tuple rather than a list; `list(...)` lets
      # the concatenation work for any sequence instead of raising TypeError.
      return tuple([estimator_spec.loss] + list(eval_metric_fn_tensors))
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:29,代码来源:xla.py


示例3: call

    def call(*args):
      """Runs method `name` of the wrapped type through `tf.py_func`.

      Positional `args` are paired with the method's parameter names (all but
      the first, presumably `self`) to ask `_tensor_specs` for the output
      specs. Those specs supply the py_func output dtypes and the static
      shapes applied to its results.

      Raises:
        ValueError: if `_tensor_specs` returns None for `name`.
      """
      param_names = function_utils.fn_args(getattr(self._type, name))[1:]
      kwargs = dict(zip(param_names, args))
      specs = self._type._tensor_specs(name, kwargs, self._constructor_kwargs)

      if specs is None:
        raise ValueError(
            'No tensor specifications were provided for: %s' % name)

      dtype_tree = nest.map_structure(lambda s: s.dtype, specs)
      flat_dtypes = nest.flatten(dtype_tree)
      shape_tree = nest.map_structure(lambda s: s.shape, specs)
      flat_shapes = nest.flatten(shape_tree)

      def py_call(*request):
        # Send the request over `self._out` (send/recv; presumably a pipe to
        # another process) and relay the reply. An Exception reply is
        # re-raised here; a None reply makes this return None.
        try:
          self._out.send(request)
          reply = self._out.recv()
          if isinstance(reply, Exception):
            raise reply
          if reply is not None:
            return reply
        except IOError:
          # Any IOError (including one relayed as a reply) ends iteration.
          raise StopIteration()  # Clean exit.

      result = tf.py_func(py_call, (name,) + tuple(args), flat_dtypes,
                          name=name)

      # py_func may hand back an Operation instead of tensors (presumably
      # when there are no outputs); return it unchanged.
      if isinstance(result, tf.Operation):
        return result

      for tensor, static_shape in zip(result, flat_shapes):
        tensor.set_shape(static_shape)
      return nest.pack_sequence_as(specs, result)
开发者ID:reinforcementdriving,项目名称:scalable_agent,代码行数:35,代码来源:py_process.py


示例4: run_step_fn

  def run_step_fn(self, step_fn):
    """Runs `step_fn`, giving it direct access to the underlying raw session.

    Args:
      step_fn: A function or method whose only parameter is `step_context`
        (of type `StepContext`), optionally preceded by `self` when it is an
        instance method. It may use methods of that argument to perform
        computations with access to a raw session.

        Whatever `step_fn` returns is returned from `run_step_fn`, unless a
        stop is requested; in that case the next `should_stop` call will
        return True.

        Example usage:

        ```python
           with tf.Graph().as_default():
             c = tf.placeholder(dtypes.float32)
             v = tf.add(c, 4.0)
             w = tf.add(c, 0.5)

             def step_fn(step_context):
               a = step_context.session.run(fetches=v, feed_dict={c: 0.5})
               if a <= 4.5:
                 step_context.request_stop()
               return step_context.run_with_hooks(fetches=w, feed_dict={c: 0.1})

             with tf.MonitoredSession() as session:
               while not session.should_stop():
                 a = session.run_step_fn(step_fn)
        ```

        Hooks interact with the `run_with_hooks()` call inside the `step_fn`
        as they do with a `MonitoredSession.run` call.

    Returns:
      The value returned by `step_fn`.

    Raises:
      StopIteration: if `step_fn` has called `request_stop()`.  It may be
        caught by `with tf.MonitoredSession()` to close the session.
      ValueError: if the argument names of `step_fn` are neither
        `(step_context,)` nor `(self, step_context)`.
    """
    step_fn_arguments = function_utils.fn_args(step_fn)
    accepted_signatures = (('step_context',), ('self', 'step_context'))
    if step_fn_arguments not in accepted_signatures:
      raise ValueError(
          '`step_fn` may either have one `step_context` argument, or'
          ' `self` and `step_context` arguments if it\'s an instance'
          ' method. Got {} instead.'.format(step_fn_arguments))

    # `self._sess` is either a `_RecoverableSession` or a `_CoordinatedSession`.
    # Passing `run_with_hooks=None` makes `run_with_hooks` resolve to
    # `_CoordinatedSession.run` downstream in either case, which lets
    # `_PREEMPTION_ERRORS` propagate from within `step_fn` up to
    # `_RecoverableSession.run_step_fn`.
    return self._sess.run_step_fn(step_fn, self._tf_sess(), run_with_hooks=None)
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:60,代码来源:monitored_session.py


示例5: test_bounded_method

  def test_bounded_method(self):
    """fn_args on a bound method reports only the explicit parameters."""

    class Foo(object):

      def bar(self, a, b):
        return a + b

    bound_method = Foo().bar
    self.assertEqual(('a', 'b'), function_utils.fn_args(bound_method))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:8,代码来源:function_utils_test.py


示例6: test_callable

  def test_callable(self):
    """fn_args on a callable instance reports `__call__`'s params minus self."""

    class Foo(object):

      def __call__(self, a, b):
        return a + b

    callable_instance = Foo()
    expected_args = ('a', 'b')
    self.assertEqual(expected_args, function_utils.fn_args(callable_instance))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:8,代码来源:function_utils_test.py


示例7: __init__

  def __init__(self, type_, *constructor_args, **constructor_kwargs):
    """Records `type_` with its constructor arguments and registers `self`.

    Positional `constructor_args` are mapped onto the parameter names of
    `type_.__init__` (skipping the first, presumably `self`), then merged with
    `constructor_kwargs`. Explicit keyword arguments win on name conflicts.
    The object is added to the `PyProcess.COLLECTION` graph collection, and a
    `_TFProxy` is built from the type and the merged keyword arguments.
    """
    self._type = type_
    init_param_names = function_utils.fn_args(type_.__init__)[1:]
    merged_kwargs = dict(zip(init_param_names, constructor_args))
    merged_kwargs.update(constructor_kwargs)
    self._constructor_kwargs = merged_kwargs

    tf.add_to_collection(PyProcess.COLLECTION, self)

    self._proxy = _TFProxy(type_, self._constructor_kwargs)
开发者ID:reinforcementdriving,项目名称:scalable_agent,代码行数:9,代码来源:py_process.py


示例8: _get_standardized_predicate_fn

def _get_standardized_predicate_fn(predicate_fn):
  """Returns a predicate callable that also accepts `checkpoint_path`.

  Args:
    predicate_fn: A callable taking `eval_results` and, optionally, a
      `checkpoint_path` argument.

  Returns:
    `predicate_fn` itself when it already declares `checkpoint_path`;
    otherwise a wrapper taking `(eval_results, checkpoint_path)` that ignores
    `checkpoint_path` and calls `predicate_fn(eval_results)`.
  """
  if "checkpoint_path" in function_utils.fn_args(predicate_fn):
    return predicate_fn

  # pylint: disable=unused-argument
  def _pred_fn_wrapper(eval_results, checkpoint_path):
    return predicate_fn(eval_results)

  return _pred_fn_wrapper
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:10,代码来源:experiment.py


示例9: _verify_estimator_spec

  def _verify_estimator_spec(self, estimator_spec):
    """Verifies estimator spec contains correct data.

    Logs a warning if `scaffold` is set and rejects `eval_metric_ops`. In EVAL
    mode, when `eval_metrics` holds a dict of tensors, checks that its keys
    match the metric function's argument names exactly. Missing `scaffold`,
    `eval_metric_ops` or `eval_metrics` attributes are tolerated.

    Args:
      estimator_spec: The spec returned by the model function.

    Returns:
      `estimator_spec`, unchanged.

    Raises:
      ValueError: if `eval_metric_ops` is set, or if a dict of eval metric
        tensors lacks arguments needed by the metric function or provides
        arguments it does not take.
    """
    # TODO(ycao): Implement estimator spec verification for other modes.

    # `scaffold` is ignored under XLA compilation: warn rather than fail.
    try:
      if estimator_spec.scaffold:
        logging.warning('EstimatorSpec.scaffold is ignored with XLA compilation'
                        '. Please use TPUEstimatorSpec.scaffold_fn instead.')
    except AttributeError:
      pass

    # `eval_metric_ops` is unsupported under XLA compilation. Only the
    # AttributeError (spec has no such attribute) is swallowed; the
    # ValueError below propagates.
    try:
      if estimator_spec.eval_metric_ops:
        raise ValueError('EstimatorSpec.eval_metric_ops is not supported with '
                         'XLA compilation. Please use '
                         'TPUEstimatorSpec.eval_metrics instead.')
    except AttributeError:
      pass

    if estimator_spec.mode == model_fn_lib.ModeKeys.EVAL:
      # If estimator_spec is of type TPUEstimatorSpec and contains eval_metrics,
      # check that eval_metrics contains eval_metric_fn and
      # eval_metric_fn_tensors with matching arguments.
      try:
        eval_metrics = estimator_spec.eval_metrics
      except AttributeError:
        eval_metrics = None

      if eval_metrics:
        (eval_metric_fn, eval_metric_fn_tensors) = eval_metrics
        eval_metric_fn_args = function_utils.fn_args(eval_metric_fn)

        # Only dict-form tensors are checked by name; sequence-form tensors
        # are not validated here.
        if isinstance(eval_metric_fn_tensors, dict):
          missing_tensors = [
              i for i in eval_metric_fn_args if i not in eval_metric_fn_tensors
          ]
          additional_tensors = [
              i for i in eval_metric_fn_tensors if i not in eval_metric_fn_args
          ]

          if missing_tensors:
            raise ValueError('Arguments %s are needed by metric_fn (first '
                             'element of TPUEstimatorSpec.eval_metrics) but '
                             'they are not provided by evaluation tensors '
                             '(second element of TPUEstimatorSpec.eval_metrics)'
                             '.' % missing_tensors)

          if additional_tensors:
            raise ValueError('Arguments %s are provided by evaluation tensors '
                             '(second element of TPUEstimatorSpec.eval_metrics)'
                             ' but they are not needed by metric_fn (first '
                             'element of TPUEstimatorSpec.eval_metrics).' %
                             additional_tensors)

    return estimator_spec
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:55,代码来源:xla.py


示例10: test_partial_function

  def test_partial_function(self):
    """fn_args omits parameters already bound by functools.partial.

    Also checks that the partial can still be called with only the remaining
    argument.
    """
    expected_test_arg = 123

    def fn(a, test_arg):
      if test_arg != expected_test_arg:
        # Raise (not return) so that a mis-bound partial fails the test.
        raise ValueError('partial fn does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg=expected_test_arg)

    self.assertEqual(('a',), function_utils.fn_args(wrapped_fn))
    self.assertEqual(3, wrapped_fn(3))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:11,代码来源:function_utils_test.py


示例11: _call_metric_fn

def _call_metric_fn(metric_fn, features, labels, predictions, config):
  """Calls `metric_fn`, passing only the keyword arguments it declares.

  Args:
    metric_fn: Callable whose parameter names decide which of `features`,
      `labels`, `predictions` and `config` are forwarded as keywords.
    features: Forwarded as `features` if `metric_fn` declares it.
    labels: Forwarded as `labels` if `metric_fn` declares it.
    predictions: Forwarded as `predictions` if `metric_fn` declares it.
    config: Forwarded as `config` if `metric_fn` declares it.

  Returns:
    The return value of `metric_fn`.
  """
  accepted = function_utils.fn_args(metric_fn)
  candidates = (
      ('features', features),
      ('labels', labels),
      ('predictions', predictions),
      ('config', config),
  )
  kwargs = {key: value for key, value in candidates if key in accepted}
  return metric_fn(**kwargs)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:13,代码来源:extenders.py


示例12: test_double_partial

  def test_double_partial(self):
    """fn_args sees through two nested functools.partial keyword bindings.

    Also checks that the doubly wrapped partial is callable with just `a`.
    """
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(a, test_arg1, test_arg2):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        # Raise (not return) so that a mis-bound partial fails the test.
        raise ValueError('partial does not work correctly')
      return a

    wrapped_fn = functools.partial(fn, test_arg2=expected_test_arg2)
    double_wrapped_fn = functools.partial(wrapped_fn,
                                          test_arg1=expected_test_arg1)

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))
    self.assertEqual(3, double_wrapped_fn(3))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:13,代码来源:function_utils_test.py


示例13: test_double_partial_with_positional_args_in_both_layers

  def test_double_partial_with_positional_args_in_both_layers(self):
    """fn_args handles positional bindings spread over two nested partials."""
    expected_test_arg1 = 123
    expected_test_arg2 = 456

    def fn(test_arg1, test_arg2, a):
      if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
        # Raise (not return) so that a wrong positional binding fails loudly
        # instead of producing an exception object as the return value.
        raise ValueError('partial fn does not work correctly')
      return a

    # The first partial binds test_arg1, the second binds test_arg2.
    wrapped_fn = functools.partial(fn, expected_test_arg1)
    double_wrapped_fn = functools.partial(wrapped_fn, expected_test_arg2)

    self.assertEqual(('a',), function_utils.fn_args(double_wrapped_fn))

    self.assertEqual(3, double_wrapped_fn(3))
    self.assertEqual(3, double_wrapped_fn(a=3))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:16,代码来源:function_utils_test.py


示例14: _call_model_fn

  def _call_model_fn(self, features, labels, mode, params):
    """Calls `self._model_fn`, forwarding only the arguments it declares.

    `features` is always passed. `labels`, `mode` and `params` are passed as
    keywords only when the model function's signature names them. The result
    is passed through `self._verify_estimator_spec` before being returned.

    Raises:
      ValueError: if `labels` is not None but the model function does not
        declare a `labels` argument.
    """
    accepted = function_utils.fn_args(self._model_fn)
    takes_labels = 'labels' in accepted
    if labels is not None and not takes_labels:
      raise ValueError(
          'model_fn does not take labels, but input_fn returns labels.')

    kwargs = {}
    if takes_labels:
      kwargs['labels'] = labels
    for key, value in (('mode', mode), ('params', params)):
      if key in accepted:
        kwargs[key] = value

    spec = self._model_fn(features=features, **kwargs)
    return self._verify_estimator_spec(spec)
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:18,代码来源:xla.py


示例15: call_logit_fn

def call_logit_fn(logit_fn, features, mode, params, config):
  """Calls logit_fn.

  A utility function that calls the provided logit_fn with the relevant subset
  of provided arguments.  Similar to tf.estimator._call_model_fn().

  Args:
    logit_fn: A logit_fn as defined above.
    features: The features dict.
    mode: TRAIN / EVAL / PREDICT ModeKeys.
    params: The hyperparameter dict.
    config: The configuration object.

  Returns:
    A logit Tensor, the output of logit_fn.

  Raises:
    ValueError: if logit_fn does not return a Tensor or a dictionary mapping
      strings to Tensors.
  """
  logit_fn_args = function_utils.fn_args(logit_fn)
  kwargs = {}
  if 'mode' in logit_fn_args:
    kwargs['mode'] = mode
  if 'params' in logit_fn_args:
    kwargs['params'] = params
  if 'config' in logit_fn_args:
    kwargs['config'] = config
  logit_fn_results = logit_fn(features=features, **kwargs)

  result_is_valid_dictionary = (
      isinstance(logit_fn_results, dict) and
      all(isinstance(k, six.string_types) and isinstance(v, ops.Tensor)
          for k, v in six.iteritems(logit_fn_results)))
  result_is_tensor = isinstance(logit_fn_results, ops.Tensor)

  if not (result_is_valid_dictionary or result_is_tensor):
    # Wrap the value in a 1-tuple. If logit_fn returned a tuple, a bare `%`
    # would treat it as the format arguments and raise TypeError instead of
    # the intended ValueError.
    raise ValueError('logit_fn should return a Tensor or a dictionary mapping '
                     'strings to Tensors.  logit_fn returned: %s' %
                     (logit_fn_results,))

  return logit_fn_results
开发者ID:AnishShah,项目名称:tensorflow,代码行数:42,代码来源:logit_fns.py


示例16: _validate_properties

def _validate_properties(run_config):
  """Validates the properties."""
  def _validate(property_name, cond, message):
    property_value = getattr(run_config, property_name)
    if property_value is not None and not cond(property_value):
      raise ValueError(message)

  _validate('model_dir', lambda dir: dir,
            message='model_dir should be non-empty')

  _validate('save_summary_steps', lambda steps: steps >= 0,
            message='save_summary_steps should be >= 0')

  _validate('save_checkpoints_steps', lambda steps: steps >= 0,
            message='save_checkpoints_steps should be >= 0')
  _validate('save_checkpoints_secs', lambda secs: secs >= 0,
            message='save_checkpoints_secs should be >= 0')

  _validate('session_config',
            lambda sc: isinstance(sc, config_pb2.ConfigProto),
            message='session_config must be instance of ConfigProto')

  _validate('keep_checkpoint_max', lambda keep_max: keep_max >= 0,
            message='keep_checkpoint_max should be >= 0')
  _validate('keep_checkpoint_every_n_hours', lambda keep_hours: keep_hours > 0,
            message='keep_checkpoint_every_n_hours should be > 0')
  _validate('log_step_count_steps', lambda num_steps: num_steps > 0,
            message='log_step_count_steps should be > 0')

  _validate('tf_random_seed', lambda seed: isinstance(seed, six.integer_types),
            message='tf_random_seed must be integer.')

  _validate('device_fn', lambda device_fn: six.callable(device_fn) and
            set(function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,
            message='device_fn must be callable with exactly'
                    ' one argument "op".')

  _validate('protocol',
            lambda protocol: protocol in (None, "grpc", "grpc+verbs"),
            message='protocol should be grpc or grpc+verbs')
开发者ID:AnishShah,项目名称:tensorflow,代码行数:40,代码来源:run_config.py


示例17: test_simple_function

 def test_simple_function(self):
   """fn_args on a plain function returns its parameter names in order."""
   def add_two(a, b):
     return a + b
   expected = ('a', 'b')
   self.assertEqual(expected, function_utils.fn_args(add_two))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:4,代码来源:function_utils_test.py


示例18: _verify_metric_fn_args

def _verify_metric_fn_args(metric_fn):
  """Checks that `metric_fn` declares only allowed argument names.

  Args:
    metric_fn: Callable whose argument names are compared against
      `_VALID_METRIC_FN_ARGS`.

  Raises:
    ValueError: if `metric_fn` has an argument outside
      `_VALID_METRIC_FN_ARGS`. The offending names are listed in sorted order.
  """
  args = set(function_utils.fn_args(metric_fn))
  # Sorted so the error message is deterministic; set iteration order of
  # strings varies between runs under hash randomization.
  invalid_args = sorted(args - _VALID_METRIC_FN_ARGS)
  if invalid_args:
    raise ValueError('metric_fn (%s) has following not expected args: %s' %
                     (metric_fn, invalid_args))
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:6,代码来源:extenders.py


示例19: __call__

  def __call__(self, inputs, *args, **kwargs):
    """Wraps `call`, applying pre- and post-processing steps.

    Arguments:
      inputs: input tensor(s).
      *args: additional positional arguments to be passed to `self.call`.
      **kwargs: additional keyword arguments to be passed to `self.call`.
        **Note**: kwarg `scope` is reserved for use by the layer.

    Returns:
      Output tensor(s).

    Note:
      - If the layer's `call` method takes a `scope` keyword argument,
        this argument will be automatically set to the current variable scope.
      - If the layer's `call` method takes a `mask` argument (as some Keras
        layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).

    Raises:
      ValueError: if the layer's `call` method returns None (an invalid value).
      ValueError: if `scope` is passed while Keras-style layers are enabled,
        or (in graph mode) if the graph of `inputs` differs from the layer's
        graph.
    """
    # `scope` is always removed from kwargs here. It is re-added below only
    # when `call` declares a `scope` parameter.
    scope = kwargs.pop('scope', None)

    # Keras-style layers bypass the variable-scope handling below and defer
    # directly to the parent class.
    if self._keras_style:
      if scope is not None:
        raise ValueError(
            'scope argument not allowed when keras style layers are enabled, '
            'but saw: {}'.format(scope))
      return super(Layer, self).__call__(inputs, *args, **kwargs)

    self._set_scope(scope)

    if not context.executing_eagerly():
      try:
        # Set layer's "graph" at build time
        self._graph = ops._get_graph_from_inputs(nest.flatten(inputs),  # pylint: disable=protected-access
                                                 graph=self._graph)
      except ValueError as e:
        raise ValueError('Input graph and Layer graph are not the same: %s' % e)

    if self.built:
      try:
        # Some classes which inherit from Layer do not use its constructor, so
        # rather than initializing to None we check for an AttributeError.
        scope_context_manager = self._always_reuse_variable_scope
      except AttributeError:
        # From this point we will always set reuse=True, so create a "final"
        # variable scope with this setting. We avoid re-creating variable scopes
        # after this point as an optimization.
        self._always_reuse_variable_scope = vs.variable_scope(
            self._scope, reuse=True, auxiliary_name_scope=False)
        scope_context_manager = self._always_reuse_variable_scope
    else:
      # Not built yet: reuse follows the layer's `_reuse` setting.
      scope_context_manager = vs.variable_scope(
          self._scope, reuse=self._reuse, auxiliary_name_scope=False)

    # Note: `scope` is rebound here to the entered variable scope; the value
    # popped from kwargs above is no longer needed.
    with scope_context_manager as scope:
      self._current_scope = scope

      # Whether `call` accepts `scope` is computed on first use and cached on
      # the instance (the AttributeError branch populates the cache).
      try:
        call_has_scope_arg = self._call_has_scope_arg
      except AttributeError:
        self._call_fn_args = function_utils.fn_args(self.call)
        self._call_has_scope_arg = 'scope' in self._call_fn_args
        call_has_scope_arg = self._call_has_scope_arg
      if call_has_scope_arg:
        kwargs['scope'] = scope

      # Actually call layer
      outputs = super(Layer, self).__call__(inputs, *args, **kwargs)

    if not context.executing_eagerly():
      # Update global default collections.
      _add_elements_to_collection(self.updates, ops.GraphKeys.UPDATE_OPS)
    return outputs
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:78,代码来源:base.py


示例20: _get_loss_towers

def _get_loss_towers(model_fn,
                     mode,
                     features,
                     labels,
                     params,
                     config,
                     devices,
                     local_ps_devices,
                     loss_reduction,
                     name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
  """Replicate the loss computation across devices.

  Calls `model_fn` once per entry of `devices` on that tower's shard of
  `features`/`labels`. Each call runs under a per-tower variable scope, name
  scope, `TowerOptimizer` tower context and device setter.

  Args:
    model_fn: Model function. `params` and `config` are passed (as deep
      copies) only if its signature names them.
    mode: Passed through to `model_fn` as `mode`.
    features: Indexable per-tower features; `features[i]` goes to tower `i`.
    labels: Indexable per-tower labels (`labels[i]` goes to tower `i`), or a
      falsy value, in which case every tower gets `labels=None`.
    params: Forwarded (deep-copied) when `model_fn` accepts `params`.
    config: Forwarded (deep-copied) when `model_fn` accepts `config`.
    devices: Sequence of worker devices, one tower per entry.
    local_ps_devices: Devices handed to the device setter with a round-robin
      strategy (presumably parameter-server devices for variables).
    loss_reduction: Reduction registered with `TowerOptimizer` and used to
      scale each tower's loss.
    name_scope_pattern: Format string taking the tower index. The first tower
      uses the empty name scope instead.

  Returns:
    A list with one spec per tower: the value returned by `model_fn` after
    `_scale_tower_loss`.

  Raises:
    ValueError: if a tower returns a `train_op` with multiple `devices` but
      `TowerOptimizer` was never used, or if the towers did not make the same
      optimizer calls.
  """
  tower_specs = []

  # Optional arguments are deep-copied so towers do not share these objects.
  model_fn_args = function_utils.fn_args(model_fn)
  optional_params = {}
  if 'params' in model_fn_args:
    optional_params['params'] = copy.deepcopy(params)
  if 'config' in model_fn_args:
    optional_params['config'] = copy.deepcopy(config)

  # pylint: disable=protected-access
  round_robin_strategy = device_setter_lib._RoundRobinStrategy(
      num_tasks=len(local_ps_devices))
  TowerOptimizer._graph_state().set_reduction_across_towers(
      loss_reduction, len(devices))

  for i, device in enumerate(devices):
    is_the_first_tower = (i == 0)

    device_setter = _local_device_setter(
        worker_device=device,
        ps_devices=local_ps_devices,
        ps_strategy=round_robin_strategy)

    # We would like to preserve the names of the variables and ops that the user
    # might be relying on. Names without a prefix are going to resolve to
    # variables and ops of the first tower.
    name_scope = name_scope_pattern
    if is_the_first_tower:
      name_scope = ''

    # Towers after the first reuse the first tower's variables. Note that
    # `name_scope` is rebound by `as name_scope` to the entered scope; it is
    # reset from `name_scope_pattern` at the top of each iteration.
    with variable_scope.variable_scope(
        '', reuse=not is_the_first_tower) as var_scope:
      with ops_lib.name_scope(name_scope.format(i)) as name_scope:
        with TowerOptimizer._graph_state().tower(
            tower_id=i, var_scope=var_scope, name_scope=name_scope):
          with ops_lib.device(device_setter):
            labels_shard = None
            if labels:
              labels_shard = labels[i]

            tower_spec = model_fn(
                mode=mode,
                features=features[i],
                labels=labels_shard,
                **optional_params)

            if (tower_spec.train_op is not None and len(devices) > 1 and
                not TowerOptimizer.has_been_used()):
              raise ValueError('Please wrap optimizers with TowerOptimizer'
                               ' in order to use replicate_model_fn with'
                               ' multiple `devices`.')

            # Scaling the loss here doesn't actually affect gradients.  Another
            # instance of scaling happens inside the TowerOptimizer.
            tower_spec = _scale_tower_loss(
                tower_spec, loss_reduction, number_of_towers=len(devices))
            tower_specs.append(tower_spec)

  if not TowerOptimizer._did_towers_have_same_optimizer_calls():
    raise ValueError('Each invocation of model_fn was supposed to make the same'
                     ' optimizer calls.')
  TowerOptimizer._clear_graph_state()
  # pylint: enable=protected-access
  return tower_specs
开发者ID:AnishShah,项目名称:tensorflow,代码行数:75,代码来源:replicate_model_fn.py



注:本文中的tensorflow.python.util.function_utils.fn_args函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python nest.assert_same_structure函数代码示例发布时间:2022-05-27
下一篇:
Python decorator_utils.validate_callable函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap