• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python context.graph_mode函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.eager.context.graph_mode函数的典型用法代码示例。如果您正苦于以下问题：Python graph_mode函数的具体用法？Python graph_mode怎么用？Python graph_mode使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了graph_mode函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: _worker_fn

    def _worker_fn(task_type, task_id, num_gpus):
      """Checks input iteration on one worker of the test cluster.

      Args:
        task_type: Job name of this worker in `self._cluster_spec`.
        task_id: Index of this worker within `task_type`.
        num_gpus: Unused.

      Returns:
        True once `_test_input_iteration` has completed.
      """
      del num_gpus
      tf_config = dict(
          cluster=self._cluster_spec,
          task=dict(type=task_type, index=task_id),
      )
      # Construct the strategy while TF_CONFIG (patched into os.environ)
      # describes this worker, holding `lock` for the duration.
      with context.graph_mode(), lock, test.mock.patch.dict(
          "os.environ", {"TF_CONFIG": json.dumps(tf_config)}):
        strategy = strategy_cls()

      with context.graph_mode(), strategy.scope(), self.cached_session(
          target="grpc://" + self._cluster_spec[task_type][task_id]) as sess:
        if tf2.enabled():
          dataset_class = dataset_ops.DatasetV2
        else:
          dataset_class = dataset_ops.Dataset

        def dataset_fn(_):
          return dataset_class.range(5).batch(2)

        autosharded = (
            input_type == "dataset" and strategy_cls is
            collective_all_reduce_strategy.CollectiveAllReduceStrategy)
        if not autosharded:
          expected_values = [[[0, 1]], [[2, 3]], [[4]]]
          input_context = None
        else:
          # With auto-sharding the two workers see different batches.
          if task_id == 0:
            expected_values = [[[0, 1]], [[4]]]
          else:
            expected_values = [[[2, 3]], [[]]]
          # input_context is for between-graph auto-sharding.
          input_context = distribute_lib.InputContext(
              num_input_pipelines=2,
              input_pipeline_id=task_id,
              num_replicas_in_sync=2)

        worker_device_pairs = [
            ("/job:%s/task:%d" % (task_type, task_id),
             strategy.extended.worker_devices)
        ]
        self._test_input_iteration(
            input_type,
            api_type,
            iteration_type,
            dataset_fn,
            worker_device_pairs,
            expected_values,
            strategy,
            sess=sess,
            enable_get_next_as_optional=True,
            input_context=input_context)
        return True
开发者ID:aritratony,项目名称:tensorflow,代码行数:48,代码来源:input_lib_test.py


示例2: test_build_standardized_signature_def_classify_classes_only

  def test_build_standardized_signature_def_classify_classes_only(self):
    """Checks the classify SignatureDef built from a lone `classes` output."""
    with context.graph_mode():
      input_placeholder = array_ops.placeholder(
          dtypes.string, 1, name='input-tensor-1')
      input_tensors = {'input-1': input_placeholder}
      classes = array_ops.placeholder(dtypes.string, 1, name='output-tensor-1')

      export_output = export_output_lib.ClassificationOutput(classes=classes)
      actual_signature_def = export_output.as_signature_def(input_tensors)

      # Input and output are both string placeholders of shape [1].
      vector_of_one = tensor_shape_pb2.TensorShapeProto(
          dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
      string_dtype = types_pb2.DataType.Value('DT_STRING')

      def _string_info(tensor_name):
        return meta_graph_pb2.TensorInfo(
            name=tensor_name, dtype=string_dtype, tensor_shape=vector_of_one)

      expected_signature_def = meta_graph_pb2.SignatureDef()
      expected_signature_def.inputs[
          signature_constants.CLASSIFY_INPUTS].CopyFrom(
              _string_info('input-tensor-1:0'))
      expected_signature_def.outputs[
          signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom(
              _string_info('output-tensor-1:0'))
      expected_signature_def.method_name = (
          signature_constants.CLASSIFY_METHOD_NAME)

      self.assertEqual(actual_signature_def, expected_signature_def)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:31,代码来源:export_output_test.py


示例3: testUsageGraph

 def testUsageGraph(self):
   """Expected usage when graph building.

   Runs three training "continuations", each in a brand-new graph. Each one
   restores from the latest checkpoint in a temp directory (none exists on the
   first pass), runs `num_training_steps` optimizer steps and saves. It then
   checks that `global_step` and `save_counter` keep accumulating across
   continuations.
   """
   with context.graph_mode():
     num_training_steps = 10
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
     for training_continuation in range(3):
       # A fresh graph per continuation, so state can only carry over through
       # the checkpoint files.
       with ops.Graph().as_default():
         model = MyModel()
         optimizer = adam.AdamOptimizer(0.001)
         root = util.Checkpoint(
             optimizer=optimizer, model=model,
             global_step=training_util.get_or_create_global_step())
         input_value = constant_op.constant([[3.]])
         train_op = optimizer.minimize(
             model(input_value),
             global_step=root.global_step)
         # None when nothing has been saved yet (see the check below).
         checkpoint_path = checkpoint_management.latest_checkpoint(
             checkpoint_directory)
         with self.session(graph=ops.get_default_graph()) as session:
           status = root.restore(save_path=checkpoint_path)
           status.initialize_or_restore(session=session)
           if checkpoint_path is None:
             # Only the first continuation has no checkpoint; with nothing
             # restored, assert_consumed() is expected to fail.
             self.assertEqual(0, training_continuation)
             with self.assertRaises(AssertionError):
               status.assert_consumed()
           else:
             status.assert_consumed()
           for _ in range(num_training_steps):
             session.run(train_op)
           root.save(file_prefix=checkpoint_prefix, session=session)
           # Steps and save count accumulate across continuations.
           self.assertEqual((training_continuation + 1) * num_training_steps,
                            session.run(root.global_step))
           self.assertEqual(training_continuation + 1,
                            session.run(root.save_counter))
开发者ID:jackd,项目名称:tensorflow,代码行数:35,代码来源:checkpointable_utils_test.py


示例4: _eager_safe_variable_handle

def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle with information to do shape inference.

  Args:
    shape: Variable shape, forwarded to `var_handle_op`.
    dtype: Variable dtype, forwarded to `var_handle_op`.
    shared_name: Resource name, forwarded to `var_handle_op`.
    name: Op name, forwarded to `var_handle_op`.
    graph_mode: If true, return the handle exactly as created. Otherwise also
      attach handle data copied from a throwaway graph-mode handle. (This bool
      parameter is unrelated to `context.graph_mode()` used below.)

  Returns:
    The handle produced by `resource_variable_ops.var_handle_op`.
  """
  # Use the default graph's container, or "" when none is set.
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  if container is None:
    container = ""
  handle = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                               shared_name=shared_name,
                                               name=name,
                                               container=container)
  if graph_mode:
    return handle

  # Build an identical handle op in a temporary graph purely to obtain its
  # handle data.
  with context.graph_mode(), ops.Graph().as_default() as graph:
    h = resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                            shared_name=shared_name,
                                            name=name,
                                            container=container)

    # Tensor._handle_data contains information for the shape-inference code to
    # know the shape and dtype of the variable pointed to by a handle. Since
    # shape inference doesn't run in eager mode we copy this data here for when
    # the handle is captured by an eager mode function.
    # pylint: disable=protected-access
    handle._handle_data = resource_variable_ops.get_resource_handle_data(h)
    # pylint: enable=protected-access
  # Clean up op->graph->op reference cycles.
  ops.dismantle_graph(graph)
  return handle
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:28,代码来源:parameter_server.py


示例5: testGraphOpNames

  def testGraphOpNames(self):
    """Network operation names should match variable naming."""

    def _check_op_prefixes(expected_prefix, checked_ops):
      """Asserts every not-yet-checked op is named under `expected_prefix`.

      Ops named "ignore" and ops already in `checked_ops` are skipped. Every
      other op is added to `checked_ops`, so each call only checks ops created
      since the previous call. It also asserts that the part of the name after
      the prefix contains neither "my_network" nor "dense".
      """
      for operation in ops.get_default_graph().get_operations():
        if operation.name == "ignore":
          continue
        if operation.name in checked_ops:
          continue
        checked_ops.add(operation.name)
        self.assertStartsWith(expected_start=expected_prefix,
                              actual=operation.name)
        self.assertNotIn("my_network", operation.name[len(expected_prefix):])
        self.assertNotIn("dense", operation.name[len(expected_prefix):])

    with context.graph_mode():
      net = MyNetwork()
      zero = constant_op.constant([[0.]], name="ignore")
      net(zero)
      checked_ops = set()
      _check_op_prefixes(expected_prefix="my_network/dense/",
                         checked_ops=checked_ops)
      # A tracked sub-network nests its ops under the parent's scope.
      net.net2 = net.track_layer(MyNetwork())
      net.net2(zero)
      _check_op_prefixes(expected_prefix="my_network/my_network/dense/",
                         checked_ops=checked_ops)
      # A second top-level network gets the uniquified "my_network_1" scope.
      MyNetwork()(zero)
      _check_op_prefixes(expected_prefix="my_network_1/dense/",
                         checked_ops=checked_ops)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:29,代码来源:network_test.py


示例6: testNameScopeWithGetVariable

  def testNameScopeWithGetVariable(self):
    """get_variable names ignore name scopes, in tower and cross-tower code.

    `a` is created under name_scope("main"), and `c` (via merge_call) under
    name_scope("foo"). Even so, each mirrored variable is named `<name>:0` on
    the first device and `<name>/replica_1:0` on the second.
    """
    def in_cross_tower(_):
      c = variable_scope.get_variable("c", [1])
      return c

    def model_fn():
      b = variable_scope.get_variable("b", [1])
      with ops.name_scope("foo"):
        c = distribute_lib.get_tower_context().merge_call(in_cross_tower)
      return b, c

    dist = mirrored_strategy.MirroredStrategy(
        ["/device:GPU:0", "/device:CPU:0"])

    with context.graph_mode(), dist.scope():
      with ops.name_scope("main"):
        a = variable_scope.get_variable("a", [1])
        result = dist.call_for_each_tower(model_fn, run_concurrently=False)
      result_b = result[0]
      result_c = result[1]
      self.assertIsInstance(result_b, values.DistributedValues)
      self.assertIsInstance(result_c, values.DistributedValues)
      a0, a1 = dist.unwrap(a)
      b0, b1 = dist.unwrap(result_b)
      c0, c1 = dist.unwrap(result_c)
      # assertEqual, not the deprecated `assertEquals` alias, which was
      # removed from unittest in Python 3.12.
      self.assertEqual("a:0", a0.name)
      self.assertEqual("a/replica_1:0", a1.name)
      self.assertEqual("b:0", b0.name)
      self.assertEqual("b/replica_1:0", b1.name)
      self.assertEqual("c:0", c0.name)
      self.assertEqual("c/replica_1:0", c1.name)
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:31,代码来源:mirrored_strategy_multigpu_test.py


示例7: decorated

    def decorated(self, **kwargs):
      """Decorated the test method.

      Runs `f` once in graph mode inside `self.test_session(...)`, then again
      in eager mode. `kwargs` is forwarded to `f` in every run, including the
      `force_gpu` eager path.
      """
      with context.graph_mode():
        with self.test_session(graph, config, use_gpu, force_gpu):
          f(self, **kwargs)

      if reset_test:
        # This decorator runs the wrapped test twice.
        # Reset the test environment between runs.
        self.tearDown()
        self.setUp()

      def run_eager_mode(self, **kwargs):
        if force_gpu:
          gpu_name = gpu_device_name()
          if not gpu_name:
            gpu_name = "/device:GPU:0"
          with context.device(gpu_name):
            # Forward kwargs here as well, like the other branches; otherwise a
            # parameterized test would not receive its arguments under
            # force_gpu.
            f(self, **kwargs)
        elif use_gpu:
          # TODO(xpan): Support softplacement and gpu by default when available.
          f(self, **kwargs)
        else:
          with context.device("/device:CPU:0"):
            f(self, **kwargs)

      if assert_no_eager_garbage:
        # Wrap the eager run in the assert_no_new_tensors and
        # assert_no_garbage_created checks.
        run_eager_mode = assert_no_new_tensors(
            assert_no_garbage_created(run_eager_mode))

      with context.eager_mode():
        with IsolateTest():
          run_eager_mode(self, **kwargs)
开发者ID:Lin-jipeng,项目名称:tensorflow,代码行数:33,代码来源:test_util.py


示例8: testAggregateGradients

  def testAggregateGradients(self):
    """Eager gradient aggregation matches the graph-mode gradient.

    `fn` mixes sparse (IndexedSlices, from embedding_lookup) and dense (from
    reduce_sum) gradient contributions for the same input. The eager gradient
    from `backprop.gradients_function` is compared with the graph-mode
    gradient, after the latter is densified with `unsorted_segment_sum`.
    """

    def fn(x):
      ind1 = constant_op.constant(np.array([0, 1]))
      ind2 = constant_op.constant(np.array([2, 3]))
      ind3 = constant_op.constant(np.array([1, 3]))
      # A mixture of IndexedSlices and dense tensor to aggregate.
      g1 = embedding_ops.embedding_lookup(x, ind1)
      g2 = embedding_ops.embedding_lookup(x, ind2)
      g3 = embedding_ops.embedding_lookup(x, ind3)
      g4 = math_ops.reduce_sum(x * constant_op.constant(2.0))
      return g1 * g2 * g3 * g4

    var_np = np.random.rand(4, 2).astype(np.float32)
    var = constant_op.constant(var_np)
    grad = backprop.gradients_function(fn, [0])(var)[0]
    grad = ops.convert_to_tensor(grad).numpy()

    with context.graph_mode(), self.test_session():
      tf_var = array_ops.constant(var_np, dtypes.float32)
      tf_ind1 = array_ops.constant([0, 1])
      tf_ind2 = array_ops.constant([2, 3])
      tf_ind3 = array_ops.constant([1, 3])
      tf_g1 = embedding_ops.embedding_lookup(tf_var, tf_ind1)
      tf_g2 = embedding_ops.embedding_lookup(tf_var, tf_ind2)
      tf_g3 = embedding_ops.embedding_lookup(tf_var, tf_ind3)
      # `axis` replaces the deprecated `reduction_indices` alias.
      tf_g4 = math_ops.reduce_sum(tf_var * 2.0, axis=(0, 1))
      tf_y = tf_g1 * tf_g2 * tf_g3 * tf_g4
      tf_grad = gradients.gradients(tf_y, [tf_var])[0]

      # The graph gradient is sparse (values/indices/dense_shape); sum it into
      # a dense tensor so it can be compared with the eager result.
      tf_dense_grad = math_ops.unsorted_segment_sum(
          tf_grad.values, tf_grad.indices, tf_grad.dense_shape[0])

      self.assertAllClose(grad, tf_dense_grad.eval())
开发者ID:DjangoPeng,项目名称:tensorflow,代码行数:34,代码来源:backprop_test.py


示例9: testAllV2SummaryOps

 def testAllV2SummaryOps(self):
   """all_v2_summary_ops() returns exactly the ops defined under a writer."""
   logdir = self.get_temp_dir()

   def define_ops():
     """Creates one op of each summary kind and returns them in a list."""
     # List elements are evaluated left to right, i.e. in creation order.
     return [
         # TF 2.0 summary ops
         summary_ops.write('write', 1, step=0),
         summary_ops.write_raw_pb(b'', step=0, name='raw_pb'),
         # TF 1.x tf.contrib.summary ops
         summary_ops.generic('tensor', 1, step=1),
         summary_ops.scalar('scalar', 2.0, step=1),
         summary_ops.histogram('histogram', [1.0], step=1),
         summary_ops.image('image', [[[[1.0]]]], step=1),
         summary_ops.audio('audio', [[1.0]], 1.0, 1, step=1),
     ]

   with context.graph_mode():
     unwritten_ops = define_ops()
     with summary_ops.create_file_writer_v2(logdir).as_default():
       with summary_ops.record_if(True):
         recorded_ops = define_ops()
       with summary_ops.record_if(False):
         unrecorded_ops = define_ops()
     # Ops defined while a default writer was present are collected whether
     # recording was on or off; ops defined with no writer at all are not.
     del unwritten_ops
     self.assertCountEqual(recorded_ops + unrecorded_ops,
                           summary_ops.all_v2_summary_ops())
开发者ID:aritratony,项目名称:tensorflow,代码行数:27,代码来源:summary_ops_test.py


示例10: testInitializableIterator

  def testInitializableIterator(self):
    """A PerDeviceDataset initializable iterator can be drained and reset."""
    with context.graph_mode():
      # Random input is only allowed with an initializable iterator.
      random_input = random_ops.random_uniform((10,))
      dataset = dataset_ops.Dataset.from_tensor_slices(random_input)
      per_device_dataset = values.PerDeviceDataset(
          dataset, ["/device:CPU:0"], prefetch_on_device=False)
      iterator = per_device_dataset.make_initializable_iterator()

      def _consume(element, times=10):
        # Pull `times` elements; the dataset holds exactly 10 slices.
        for _ in range(times):
          self.evaluate(element)

      self.evaluate(iterator.initializer)
      next_element = iterator.get_next()
      _consume(next_element)

      # The input is now exhausted, so one more pull must fail.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next_element)

      # Re-running the initializer makes the iterator usable again.
      self.evaluate(iterator.initializer)
      _consume(next_element)
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:25,代码来源:values_test.py


示例11: testDataDistributionNoAutoShard

 def testDataDistributionNoAutoShard(self):
   """Distributes range(4) across CPU devices with auto_shard=False."""
   worker_devices, devices = self._cpu_devices()
   unsharded_expectation = [[0, 0], [1, 1], [2, 2], [3, 3]]
   with context.graph_mode():
     self._test_dataset(lambda: dataset_ops.Dataset.range(4),
                        worker_devices, devices, unsharded_expectation,
                        auto_shard=False)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:7,代码来源:values_test.py


示例12: test_training_no_default

  def test_training_no_default(self):
    """The tensor passed to the model is recorded in `model.inputs`."""

    with context.graph_mode():
      subclassed = TrainingNoDefaultModel()
      ones_input = array_ops.ones([1, 1])
      # Second positional arg: presumably the `training` flag, which this
      # model declares without a default (TODO confirm).
      subclassed(ones_input, True)
      six.assertCountEqual(self, [ones_input], subclassed.inputs)
开发者ID:StephenOman,项目名称:tensorflow,代码行数:7,代码来源:model_subclassing_test.py


示例13: testDataDistributionOneDevicePerWorker

 def testDataDistributionOneDevicePerWorker(self):
   """Distributes range(8) with one device per worker (currently skipped)."""
   self.skipTest("Temporarily disabled.")
   # skipTest raises, so nothing below runs until the skip is removed.
   cpu_worker_device_map, cpu_devices = self._cpu_devices()
   expected_batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
   with context.graph_mode():
     self._test_dataset(lambda: dataset_ops.Dataset.range(8),
                        cpu_worker_device_map, cpu_devices, expected_batches)
开发者ID:baojianzhou,项目名称:tensorflow,代码行数:7,代码来源:values_test.py


示例14: _defun_internal

def _defun_internal(name, func, args, kwds):
  """Defines and returns graph-mode version of func.

  Traces `func` into a fresh `CapturingGraph` and wraps the resulting
  operations in a `GraphModeFunction`.

  Args:
    name: Base name, passed through `_inference_name` for the function name.
    func: The Python callable to trace.
    args: Positional arguments, converted to graph inputs by
      `_get_defun_inputs` and passed to `func`.
    kwds: Keyword arguments, passed to `func` unchanged.

  Returns:
    A `GraphModeFunction` built from the traced graph.
  """
  # Read from the caller's default graph before entering graph mode.
  graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
  with context.graph_mode():
    # Populated while tracing; each value unpacks into an
    # (outer input, placeholder) pair (see the zip below).
    captures = {}
    tmp_graph = CapturingGraph(captures)
    # Inherit the graph key, since this is used for matching variables in
    # optimizers.
    tmp_graph._graph_key = graph_key  # pylint: disable=protected-access
    # Copy the graph collections to ensure summaries and other things work. This
    # lets the function access (but not mutate) collections of the containing
    # graph, such as the global step and the summary writer collections.
    # (tmp_graph only becomes the default further below, so this is still the
    # caller's graph.)
    curr_graph = ops.get_default_graph()
    for collection in curr_graph.collections:
      tmp_graph.get_collection_ref(collection)[:] = curr_graph.get_collection(
          collection)
    with tmp_graph.as_default():
      func_inputs = _get_defun_inputs(args)

      with capture_tensors(captures):
        # Push a tape for the duration of the call so the variables `func`
        # touches can be read back with watched_variables().
        this_tape = tape.push_new_tape()
        try:
          func_outputs = func(*func_inputs, **kwds)
        finally:
          tape.pop_tape(this_tape)
        variables = this_tape.watched_variables()

        # Returning a closed-over tensor as an output does not trigger a
        # call to convert_to_tensor, so we manually capture all such tensors.
        outputs_list = _flatten(func_outputs)
        func_def_outputs = [
            _convert_to_graph_tensor(x) for x in outputs_list if x is not None
        ]

      # Sort capture keys so the extra inputs have a deterministic order.
      ids = list(sorted(captures.keys()))
      if ids:
        extra_inputs, extra_placeholders = zip(* [captures[x] for x in ids])
      else:
        extra_inputs = []
        extra_placeholders = []
      # Non-tensor outputs get a shape of None.
      output_shapes = tuple(
          x.shape if isinstance(x, ops.Tensor) else None
          for x in outputs_list)

  flat_inputs = [x for x in nest.flatten(func_inputs)
                 if isinstance(x, ops.Tensor)]
  all_inputs = flat_inputs + list(extra_placeholders)
  # The ops that produce the inputs are excluded from the function body.
  all_ignored_ops = frozenset(x.op for x in all_inputs)
  fname = _inference_name(name)
  operations = tuple(x for x in tmp_graph.get_operations()
                     if x not in all_ignored_ops)
  # Register any other functions defined in the graph
  # TODO(ashankar): Oh lord, forgive me for this lint travesty.
  if context.in_eager_mode():
    for f in tmp_graph._functions.values():  # pylint: disable=protected-access
      # TODO(ashankar): What about the gradient registry?
      _register(f._c_func)  # pylint: disable=protected-access
  return GraphModeFunction(
      fname, all_inputs, extra_inputs, tmp_graph, operations, func_def_outputs,
      func_outputs, output_shapes, variables)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:60,代码来源:function.py


示例15: _eager_safe_variable_handle

def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
  """Creates a variable handle with information to do shape inference.

  Args:
    shape: Variable shape, forwarded to `var_handle_op`.
    dtype: Variable dtype, forwarded to `var_handle_op`.
    shared_name: Resource name, forwarded to `var_handle_op`; also used in
      the error message.
    name: Op name, forwarded to `var_handle_op`.
    graph_mode: If true, return the handle exactly as created. Otherwise check
      that the resource is not already initialized, then attach handle data
      copied from a throwaway graph-mode handle. (This bool parameter is
      unrelated to `context.graph_mode()` used below.)

  Returns:
    The handle produced by `gen_resource_variable_ops.var_handle_op`.

  Raises:
    ValueError: When not in graph mode and `var_is_initialized_op` reports the
      resource as already initialized.
  """
  # Use the default graph's container, or "" when none is set.
  container = ops.get_default_graph()._container  # pylint: disable=protected-access
  if container is None:
    container = ""
  handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                                   shared_name=shared_name,
                                                   name=name,
                                                   container=container)
  if graph_mode:
    return handle

  # We do not want two distinct ResourceVariable objects for the same
  # underlying resource in the runtime.
  # When in eager mode, explicitly ensure so here. When in graph mode, it's
  # ensured by always generating different variable names.
  # NOTE(review): `exists` is an eager tensor here; `if` uses its truth value.
  exists = gen_resource_variable_ops.var_is_initialized_op(handle)
  if exists:
    raise ValueError("variable object with name '%s' already created. Use "
                     "get_variable() if reuse is desired." %
                     shared_name)
  # Build an identical handle op in a temporary graph purely to obtain its
  # handle data.
  with context.graph_mode(), ops.Graph().as_default():
    h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype,
                                                shared_name=shared_name,
                                                name=name,
                                                container=container)

    # Tensor._handle_data contains information for the shape-inference code to
    # know the shape and dtype of the variable pointed to by a handle. Since
    # shape inference doesn't run in eager mode we copy this data here for when
    # the handle is captured by an eager mode function.
    handle._handle_data = h._handle_data  # pylint: disable=protected-access
  return handle
开发者ID:DjangoPeng,项目名称:tensorflow,代码行数:33,代码来源:resource_variable_ops.py


示例16: benchmark_keras_model_functional_fit_graph_mode_with_profiler

 def benchmark_keras_model_functional_fit_graph_mode_with_profiler(self):
   """Benchmarks graph-mode fit() of a functional Keras model while profiling.

   Raises:
     AssertionError: If `profiler.stop()` returns None.
   """
   profiler.start()
   try:
     with context.graph_mode():
       model = make_keras_model(initializer="glorot_uniform")
       self._benchmark_keras_model_fit(model)
   finally:
     # Stop the profiler even if the benchmark raises, so it is not left
     # running for whatever executes next in this process.
     result = profiler.stop()
   # Explicit check instead of `assert`, which `python -O` strips out.
   if result is None:
     raise AssertionError("profiler.stop() returned None")
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:7,代码来源:benchmarks_test.py


示例17: _compute_backprop

 def _compute_backprop(self):
   """Computes the backprop function object for this function.

   Inside `self._graph`, builds gradients of the non-None outputs with respect
   to `self._input_placeholders`. It then registers a forward function (the
   original outputs plus any captured tensors) and a backward function, and
   stores them on `self._forward_fdef` and `self._backward_function`.
   """
   self._has_backprop = True
   with self._graph.as_default(), context.graph_mode():
     # Exposes `captured_tensors` and `known_ops`, which are used below to
     # assemble the forward and backward functions.
     c = _CapturingContext()
     with c:
       filtered_outputs = [
           x for x in self._returns if x is not None
       ]
       # One placeholder per non-None output, fed as its incoming gradient.
       self._out_grad_placeholders = [
           graph_placeholder(x.dtype, x.shape) for x in filtered_outputs
       ]
       in_gradients = gradients_impl.gradients(
           filtered_outputs,
           self._input_placeholders,
           grad_ys=self._out_grad_placeholders)
       shapes = [x.shape for x in in_gradients if x is not None]
   # Sort by name so that the extra outputs/inputs have a deterministic order.
   captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
   # The forward function also outputs the captured tensors, which the backward
   # function takes as extra inputs.
   forward_function_def = make_function_def(
       self._graph, self._ops, self._input_placeholders,
       filtered_outputs + captures)
   self._forward_fdef = _DefinedFunction(forward_function_def)
   _register_with_name(_forward_name(self._func_name), forward_function_def)
   backward_outputs = [x for x in in_gradients if x is not None]
   all_inputs = self._out_grad_placeholders + captures
   # Backward body: the gradient placeholders' ops plus the ops known to `c`.
   backward_function_def = make_function_def(
       self._graph, [x.op for x in self._out_grad_placeholders
                    ] + list(sorted(c.known_ops, key=lambda x: x.name)),
       all_inputs, backward_outputs)
   _register_with_name(_backward_name(self._func_name), backward_function_def)
   self._backward_function = _GraphModeFunction(
       all_inputs, [], backward_function_def, self._graph, c.known_ops,
       in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes)
开发者ID:SylChan,项目名称:tensorflow,代码行数:33,代码来源:function.py


示例18: test_training_no_default

  def test_training_no_default(self):
    """Calling the model records exactly one input in `model.inputs`."""

    with context.graph_mode():
      subclassed = TrainingNoDefaultModel()
      ones_input = array_ops.ones([1, 1])
      # Second positional arg: presumably the `training` flag, which this
      # model declares without a default (TODO confirm).
      subclassed(ones_input, True)
      self.assertEqual(1, len(subclassed.inputs))
开发者ID:aeverall,项目名称:tensorflow,代码行数:7,代码来源:model_subclassing_test.py


示例19: testNamedTupleEstimatorSpec

  def testNamedTupleEstimatorSpec(self):
    """regroup() merges per-device EstimatorSpecs; select_replica() splits them.

    Builds one TRAIN EstimatorSpec per device for three devices. It regroups
    them into a single EstimatorSpec whose `loss`, `train_op` and `scaffold`
    hold per-device values, then checks the mapping in both directions.
    """
    with context.graph_mode(), ops.Graph().as_default():
      devices = []
      created_estimator_specs = []

      for device_id in range(3):
        spec = model_fn_lib.EstimatorSpec(
            mode=model_fn_lib.ModeKeys.TRAIN,
            loss=constant_op.constant(device_id / 2),
            train_op=array_ops.identity(constant_op.constant(device_id)))
        devices.append(_device_str(device_id))
        created_estimator_specs.append(spec)

      device_map = values.ReplicaDeviceMap(devices)
      merged_estimator_spec = values.regroup(
          device_map, created_estimator_specs)

      # assertIsInstance reports the actual type on failure, unlike
      # assertTrue(isinstance(...)).
      self.assertIsInstance(merged_estimator_spec, model_fn_lib.EstimatorSpec)
      self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged_estimator_spec.mode)
      for device_id in range(3):
        d = _device_str(device_id)
        self.assertEqual(created_estimator_specs[device_id].loss,
                         merged_estimator_spec.loss.get(d))
        self.assertEqual(created_estimator_specs[device_id].train_op,
                         merged_estimator_spec.train_op.get(d))
        # Scaffold is populated by `EstimatorSpec.__new__`.
        self.assertEqual(created_estimator_specs[device_id].scaffold,
                         merged_estimator_spec.scaffold.get(d))
        # Also test that we can undo the merge using select_replica()
        self.assertEqual(created_estimator_specs[device_id],
                         values.select_replica(device_id,
                                               merged_estimator_spec))
开发者ID:kylin9872,项目名称:tensorflow,代码行数:33,代码来源:values_test.py


示例20: decorated

    def decorated(self):
      """Runs the wrapped test in graph mode, then again in eager mode."""
      with context.graph_mode():
        with self.test_session(graph, config, use_gpu, force_gpu):
          f(self)

      if reset_test:
        # The wrapped test runs twice; give the second run a fresh fixture.
        self.tearDown()
        self.setUp()

      def run_eager_mode():
        # use_gpu without force_gpu: run with no explicit device placement.
        if use_gpu and not force_gpu:
          # TODO(xpan): Support softplacement and gpu by default when available.
          f(self)
          return
        if force_gpu:
          # Fall back to a default GPU name when none is reported.
          device_name = gpu_device_name() or "/device:GPU:0"
        else:
          device_name = "/device:CPU:0"
        with context.device(device_name):
          f(self)

      eager_graph = graph or ops.Graph()
      with context.eager_mode():
        with eager_graph.as_default():
          run_eager_mode()
开发者ID:keveman,项目名称:tensorflow,代码行数:30,代码来源:test_util.py



注：本文中的tensorflow.python.eager.context.graph_mode函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python context.in_eager_mode函数代码示例（发布时间：2022-05-27）
下一篇:
Python context.get_default_context函数代码示例（发布时间：2022-05-27）
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap