• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python tune.run_experiments函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中ray.tune.run_experiments函数的典型用法代码示例。如果您正苦于以下问题:Python run_experiments函数的具体用法?Python run_experiments怎么用?Python run_experiments使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了run_experiments函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: launch

    def launch(self):
        """Run the complete training job for this launcher instance.

        Registers the environment creator and builds the Ray init config.
        On non-master nodes the method then returns ``None``. On the master
        node it starts Ray, runs the customized experiment config through
        ``run_experiments``, creates the termination signal when other hosts
        exist, and saves the checkpoint and serving model.
        """
        self.register_env_creator()

        # Per the original author's note, worker nodes block inside this call
        # while training is in progress.
        cluster_kwargs = self.ray_init_config()
        if not self.is_master_node:
            return

        # From here on only the master node is executing: it acts as driver.
        ray.init(**cluster_kwargs)
        experiment_config = self.customize_experiment_config(
            self.get_experiment_config())
        pretty_config = json.dumps(experiment_config, indent=2)
        print("Running experiment with config %s" % pretty_config)
        run_experiments(experiment_config)

        # Every host after the first one is treated as a worker. If there is
        # at least one, this is a distributed job and TERMINATION_SIGNAL is sent.
        worker_host_names = self.get_all_host_names()[1:]
        if len(worker_host_names) > 0:
            self.sage_cluster_communicator.create_s3_signal(TERMINATION_SIGNAL)

        training_spec = experiment_config["training"]
        algorithm = training_spec["run"]
        training_config = training_spec["config"]
        env_name = training_config["env"]
        self.save_checkpoint_and_serving_model(
            algorithm=algorithm,
            env_string=env_name,
            config=training_config)
开发者ID:FNDaily,项目名称:amazon-sagemaker-examples,代码行数:28,代码来源:ray_launcher.py


示例2: run

 def run():
     """Run experiment "foo" with actor reuse and a FrequentPausesScheduler.

     The experiment uses ``MyResettableClass`` with four samples, at most one
     allowed failure, and ``fake_reset_not_supported`` set in its config.
     Nothing is returned.
     """
     foo_spec = {
         "run": MyResettableClass,
         "max_failures": 1,
         "num_samples": 4,
         "config": {
             "fake_reset_not_supported": True
         },
     }
     run_experiments(
         {"foo": foo_spec},
         reuse_actors=True,
         scheduler=FrequentPausesScheduler())
开发者ID:robertnishihara,项目名称:ray,代码行数:14,代码来源:test_actor_reuse.py


示例3: test_cluster_rllib_restore

def test_cluster_rllib_restore(start_connected_cluster, tmpdir):
    """Check that a "PG" experiment can be resumed on a freshly started cluster.

    A separate, non-blocking driver script runs the experiment with
    ``checkpoint_freq=1``. Once the experiment checkpoint reports a
    ``training_iteration``, Ray and the cluster are shut down, a new cluster
    is started, and the experiment is run again with ``resume=True``. Every
    resumed trial must end with status ``Trial.TERMINATED``.
    """
    cluster = start_connected_cluster
    dirpath = str(tmpdir)
    driver_template = """
import time
import ray
from ray import tune

ray.init(redis_address="{redis_address}")

kwargs = dict(
    run="PG",
    env="CartPole-v1",
    stop=dict(training_iteration=10),
    local_dir="{checkpoint_dir}",
    checkpoint_freq=1,
    max_failures=1)

tune.run_experiments(
    dict(experiment=kwargs),
    raise_on_failed_trial=False)
"""
    driver_script = driver_template.format(
        redis_address=cluster.redis_address, checkpoint_dir=dirpath)
    run_string_as_driver_nonblocking(driver_script)

    # Poll (at most 100 times, 0.3 s apart) until the experiment checkpoint
    # exists and its first trial has reported a training iteration.
    metadata_checkpoint_dir = os.path.join(dirpath, "experiment")
    for _ in range(100):
        if TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
            # Look inside the checkpointed trial runner for progress.
            restored_runner = TrialRunner.restore(metadata_checkpoint_dir)
            first_trial = restored_runner.get_trials()[0]
            progress = first_trial.last_result
            if progress and progress.get("training_iteration"):
                break
        time.sleep(0.3)

    if not TrialRunner.checkpoint_exists(metadata_checkpoint_dir):
        raise RuntimeError("Checkpoint file didn't appear.")

    # Tear down the current cluster and bring up a replacement.
    ray.shutdown()
    cluster.shutdown()
    cluster = _start_new_cluster()
    cluster.wait_for_nodes()

    # Resume the experiment from the checkpoint written by the driver script.
    resume_spec = {
        "run": "PG",
        "checkpoint_freq": 1,
        "local_dir": dirpath
    }
    trials2 = tune.run_experiments({"experiment": resume_spec}, resume=True)
    assert all(trial.status == Trial.TERMINATED for trial in trials2)
    cluster.shutdown()
开发者ID:robertnishihara,项目名称:ray,代码行数:58,代码来源:test_cluster.py


示例4: test_ls

def test_ls(start_ray, capsys, tmpdir):
    """Check the output of ``commands.list_trials`` for a finished experiment.

    Runs a two-sample "__fake" experiment, lists its trials, and asserts that
    the number of output lines containing "TERMINATED" equals the sample count.
    """
    experiment_name = "test_ls"
    local_dir = str(tmpdir)
    experiment_path = os.path.join(local_dir, experiment_name)
    num_samples = 2
    experiment_spec = {
        "run": "__fake",
        "stop": {
            "training_iteration": 1
        },
        "num_samples": num_samples,
        "local_dir": local_dir
    }
    # The experiment itself runs with pytest's capture disabled, so its
    # output is not part of what readouterr() returns below.
    with capsys.disabled():
        tune.run_experiments({experiment_name: experiment_spec})

    commands.list_trials(experiment_path, info_keys=("status", ))
    output_lines = capsys.readouterr().out.strip().split("\n")
    terminated_lines = [line for line in output_lines if "TERMINATED" in line]
    assert len(terminated_lines) == num_samples
开发者ID:robertnishihara,项目名称:ray,代码行数:21,代码来源:test_commands.py


示例5: testTrialReuseEnabled

 def testTrialReuseEnabled(self):
     """With ``reuse_actors=True`` the four trials must report 1..4 resets.

     Runs four samples of ``MyResettableClass`` under a
     ``FrequentPausesScheduler`` and compares each trial's
     ``last_result["num_resets"]`` against ``[1, 2, 3, 4]``.
     """
     experiments = {
         "foo": {
             "run": MyResettableClass,
             "num_samples": 4,
             "config": {},
         }
     }
     trials = run_experiments(
         experiments,
         reuse_actors=True,
         scheduler=FrequentPausesScheduler())
     reset_counts = [trial.last_result["num_resets"] for trial in trials]
     self.assertEqual(reset_counts, [1, 2, 3, 4])
开发者ID:robertnishihara,项目名称:ray,代码行数:13,代码来源:test_actor_reuse.py


示例6: test_lsx

def test_lsx(start_ray, capsys, tmpdir):
    """Check the output of ``commands.list_experiments`` for three experiments.

    Runs three single-sample "__fake" experiments into the same project
    directory, lists them with the ``total_trials`` column, and asserts that
    at least as many output lines as experiments contain the character "1".
    """
    project_path = str(tmpdir)
    num_experiments = 3
    for index in range(num_experiments):
        experiment_spec = {
            "run": "__fake",
            "stop": {
                "training_iteration": 1
            },
            "num_samples": 1,
            "local_dir": project_path
        }
        # Run with pytest's capture disabled, so this output is not part of
        # what readouterr() returns below.
        with capsys.disabled():
            tune.run_experiments({
                "test_lsx{}".format(index): experiment_spec
            })

    commands.list_experiments(project_path, info_keys=("total_trials", ))
    output_lines = capsys.readouterr().out.strip().split("\n")
    lines_with_one = sum(1 for line in output_lines if "1" in line)
    assert lines_with_one >= num_experiments
开发者ID:robertnishihara,项目名称:ray,代码行数:22,代码来源:test_commands.py


示例7: test_cluster_down_full

def test_cluster_down_full(start_connected_cluster, tmpdir):
    """Check that ``run_experiments`` can resume after the cluster is shut down.

    Four "__fake" experiments (with/without checkpointing, with/without the
    ``mock_error`` config flag) are run once, the cluster is replaced, and
    the same experiments are run again with ``resume=True``. Four trials must
    come back, each in status ``Trial.TERMINATED`` or ``Trial.ERROR``.
    """
    cluster = start_connected_cluster
    dirpath = str(tmpdir)

    all_experiments = {
        # Checkpointing into the test's temporary directory.
        "exp1": {
            "run": "__fake",
            "stop": {"training_iteration": 3},
            "local_dir": dirpath,
            "checkpoint_freq": 1,
        },
        # No checkpointing, default local_dir.
        "exp2": {
            "run": "__fake",
            "stop": {"training_iteration": 3},
        },
        # mock_error set, no checkpointing.
        "exp3": {
            "run": "__fake",
            "stop": {"training_iteration": 3},
            "config": {"mock_error": True},
        },
        # mock_error set, with checkpointing.
        "exp4": {
            "run": "__fake",
            "stop": {"training_iteration": 3},
            "config": {"mock_error": True},
            "checkpoint_freq": 1,
        },
    }

    tune.run_experiments(all_experiments, raise_on_failed_trial=False)

    # Shut everything down and start a replacement cluster.
    ray.shutdown()
    cluster.shutdown()
    cluster = _start_new_cluster()

    trials = tune.run_experiments(
        all_experiments, resume=True, raise_on_failed_trial=False)
    assert len(trials) == 4
    accepted_statuses = [Trial.TERMINATED, Trial.ERROR]
    assert all(trial.status in accepted_statuses for trial in trials)
    cluster.shutdown()
开发者ID:robertnishihara,项目名称:ray,代码行数:38,代码来源:test_cluster.py


示例8: f

#!/usr/bin/env python

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import ray
from ray.tune import register_trainable, run_experiments


def f(config, reporter):
    """Minimal trainable: report a single result with ``timesteps_total=1``.

    ``config`` is accepted but not read; ``reporter`` is called exactly once
    with the keyword argument ``timesteps_total=1``.
    """
    result = {"timesteps_total": 1}
    reporter(**result)


if __name__ == "__main__":
    # Run a one-iteration experiment with the function trainable registered
    # as "my_class", then verify that doing so did not import ray.rllib.
    ray.init()
    register_trainable("my_class", f)
    test_spec = {
        "run": "my_class",
        "stop": {
            "training_iteration": 1
        }
    }
    run_experiments({"test": test_spec})
    assert 'ray.rllib' not in sys.modules, "RLlib should not be imported"
开发者ID:jamescasbon,项目名称:ray,代码行数:28,代码来源:dependency_test.py


示例9: grid_search

    args, _ = parser.parse_known_args()

    mnist_spec = {
        'run': train,
        'num_samples': 10,
        'stop': {
            'mean_accuracy': 0.99,
            'timesteps_total': 600,
        },
        'config': {
            'activation': grid_search(['relu', 'elu', 'tanh']),
        },
    }

    if args.smoke_test:
        mnist_spec['stop']['training_iteration'] = 2
        mnist_spec['num_samples'] = 1

    ray.init()

    from ray.tune.schedulers import AsyncHyperBandScheduler
    run_experiments(
        {
            'tune_mnist_test': mnist_spec
        },
        scheduler=AsyncHyperBandScheduler(
            time_attr="timesteps_total",
            reward_attr="mean_accuracy",
            max_t=600,
        ))
开发者ID:jamescasbon,项目名称:ray,代码行数:30,代码来源:tune_mnist_async_hyperband.py


示例10: register_trainable

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--smoke-test', action='store_true', help='Finish quickly for testing')
    args, _ = parser.parse_known_args()

    register_trainable("my_class", TrainMNIST)
    mnist_spec = {
        'run': 'my_class',
        'stop': {
          'mean_accuracy': 0.99,
          'time_total_s': 600,
        },
        'config': {
            'learning_rate': lambda spec:  10 ** np.random.uniform(-5, -3),
            'activation': grid_search(['relu', 'elu', 'tanh']),
        },
        "repeat": 10,
    }

    if args.smoke_test:
        mnist_spec['stop']['training_iteration'] = 2
        mnist_spec['repeat'] = 2

    ray.init()
    hyperband = HyperBandScheduler(
        time_attr="timesteps_total", reward_attr="mean_accuracy",
        max_t=100)

    run_experiments(
        {'mnist_hyperband_test': mnist_spec}, scheduler=hyperband)
开发者ID:adgirish,项目名称:ray,代码行数:30,代码来源:tune_mnist_ray_hyperband.py


示例11: range

    for i in range(args.num_iters):
        if i % 10 == 0:
            start = time.time()
            loss = sgd.step(fetch_stats=True)["loss"]
            metrics = sgd.foreach_model(lambda model: model.get_metrics())
            acc = [m["accuracy"] for m in metrics]
            print("Iter", i, "loss", loss, "accuracy", acc)
            print("Time per iteration", time.time() - start)
            assert len(set(acc)) == 1, ("Models out of sync", acc)
            reporter(timesteps_total=i, mean_loss=loss, mean_accuracy=acc[0])
        else:
            sgd.step()


if __name__ == "__main__":
    args = parser.parse_args()
    ray.init(redis_address=args.redis_address)

    if not args.tune:
        # Plain run: call the training function directly with a reporter
        # that accepts any keyword arguments and discards them.
        train_mnist({"args": args}, lambda **kw: None)
    else:
        # Tune run: hand the same args to train_mnist through the config.
        mnist_sgd_spec = {
            "run": train_mnist,
            "config": {
                "args": args,
            },
        }
        run_experiments({"mnist_sgd": mnist_sgd_spec})
开发者ID:robertnishihara,项目名称:ray,代码行数:29,代码来源:mnist_example.py


示例12: register_env

    result = subprocess.check_output(
        "ps aux | grep '{}' | grep -v grep || true".format(UNIQUE_CMD),
        shell=True)
    return result


if __name__ == "__main__":
    # Run one iteration of "PG" on the "subproc" env. Before the run both
    # marker files must exist and no matching process may be running; after
    # the run both marker files must be gone and no process may be left.
    register_env("subproc", lambda config: EnvWithSubprocess(config))
    ray.init()
    for marker_file in (UNIQUE_FILE_0, UNIQUE_FILE_1):
        assert os.path.exists(marker_file)
    assert not leaked_processes()

    demo_spec = {
        "run": "PG",
        "env": "subproc",
        "num_samples": 1,
        "config": {
            "num_workers": 1,
        },
        "stop": {
            "training_iteration": 1
        },
    }
    run_experiments({"demo": demo_spec})

    leaked = leaked_processes()
    assert not leaked, "LEAKED PROCESSES: {}".format(leaked)
    for marker_file in (UNIQUE_FILE_0, UNIQUE_FILE_1):
        assert not os.path.exists(marker_file), "atexit handler not called"
    print("OK")
开发者ID:jamescasbon,项目名称:ray,代码行数:30,代码来源:test_env_with_subprocess.py


示例13: SigOptSearch

        },
        {
            'name': 'height',
            'type': 'int',
            'bounds': {
                'min': -100,
                'max': 100
            },
        },
    ]

    config = {
        "my_exp": {
            "run": "exp",
            "num_samples": 10 if args.smoke_test else 1000,
            "config": {
                "iterations": 100,
            },
            "stop": {
                "timesteps_total": 100
            },
        }
    }
    algo = SigOptSearch(
        space,
        name="SigOpt Example Experiment",
        max_concurrent=1,
        reward_attr="neg_mean_loss")
    scheduler = AsyncHyperBandScheduler(reward_attr="neg_mean_loss")
    run_experiments(config, search_alg=algo, scheduler=scheduler)
开发者ID:robertnishihara,项目名称:ray,代码行数:30,代码来源:sigopt_example.py


示例14: PopulationBasedTraining

        num_cpus=10,
        num_gpus=0,
        resources={str(i): 2},
        object_store_memory=object_store_memory,
        redis_max_memory=redis_max_memory)
ray.init(redis_address=cluster.redis_address)

# Run the workload.

# Population-based training scheduler: uses training_iteration as the time
# attribute and episode_reward_mean as the reward attribute, with a
# perturbation interval of 10 and "lr" mutated among the listed values.
lr_choices = [0.1, 0.01, 0.001, 0.0001]
pbt = PopulationBasedTraining(
    time_attr="training_iteration",
    reward_attr="episode_reward_mean",
    perturbation_interval=10,
    hyperparam_mutations={
        "lr": lr_choices,
    })

# Eight samples of "PG" on CartPole-v0, all starting from lr=0.01.
pbt_test_spec = {
    "run": "PG",
    "env": "CartPole-v0",
    "num_samples": 8,
    "config": {
        "lr": 0.01,
    },
}
run_experiments({"pbt_test": pbt_test_spec}, scheduler=pbt, verbose=False)
开发者ID:robertnishihara,项目名称:ray,代码行数:30,代码来源:pbt.py


示例15: run_experiments

# Launch the "carla-dqn" experiment: DQN on the "carla_env" environment with
# the custom "carla" model, requesting 4 CPUs and 1 GPU.
run_experiments({
    "carla-dqn": {
        "run": "DQN",
        "env": "carla_env",
        "resources": {"cpu": 4, "gpu": 1},
        "config": {
            "env_config": env_config,
            "model": {
                "custom_model": "carla",
                "custom_options": {
                    "image_shape": [
                        80, 80,
                        # Third dimension is a callable taking `spec`:
                        # framestack * 1 when use_depth_camera is truthy,
                        # otherwise framestack * 3. Written as a conditional
                        # expression instead of the fragile `x and 1 or 3`.
                        # NOTE(review): presumably Tune resolves this callable
                        # when generating the trial config - confirm for the
                        # Ray version in use.
                        lambda spec: spec.config.env_config.framestack * (
                            1 if spec.config.env_config.use_depth_camera
                            else 3
                        ),
                    ],
                },
                "conv_filters": [
                    [16, [8, 8], 4],
                    [32, [4, 4], 2],
                    [512, [10, 10], 1],
                ],
            },
            "timesteps_per_iteration": 100,
            "learning_starts": 1000,
            "schedule_max_timesteps": 100000,
            "gamma": 0.8,
            "tf_session_args": {
                "gpu_options": {"allow_growth": True},
            },
        },
    },
})
开发者ID:adgirish,项目名称:ray,代码行数:33,代码来源:train_dqn.py


示例16: run_experiments

        config = {
            "num_gpus": 0,
            "num_workers": 2,
            "optimizer": {
                "num_replay_buffer_shards": 1,
            },
            "min_iter_time_s": 3,
            "buffer_size": 1000,
            "learning_starts": 1000,
            "train_batch_size": 128,
            "sample_batch_size": 32,
            "target_network_update_freq": 500,
            "timesteps_per_iteration": 1000,
        }
        group = True
    else:
        config = {}
        group = False

    ray.init()
    run_experiments({
        "two_step": {
            "run": args.run,
            "env": "grouped_twostep" if group else TwoStepGame,
            "stop": {
                "timesteps_total": args.stop,
            },
            "config": config,
        },
    })
开发者ID:robertnishihara,项目名称:ray,代码行数:30,代码来源:twostep_game.py


示例17: step

    def step(self, action):
        """Apply one action and return ``([position], reward, done, {})``.

        Action 1 increments ``self.cur_pos``; action 0 decrements it, but
        only while it is above zero. The episode is done once ``cur_pos``
        reaches ``self.end_pos``; the reward is 1 on that step and 0
        otherwise. Any other action fails the assertion.
        """
        assert action in [0, 1], action
        if action == 1:
            self.cur_pos += 1
        elif action == 0 and self.cur_pos > 0:
            self.cur_pos -= 1
        reached_end = self.cur_pos >= self.end_pos
        reward = 1 if reached_end else 0
        return [self.cur_pos], reward, reached_end, {}


if __name__ == "__main__":
    # The env class is passed directly. Alternatively, register a creator:
    #   register_env("corridor", lambda config: SimpleCorridor(config))
    # and use the string "corridor" as the "env" value below.
    ray.init()
    ppo_config = {
        "lr": grid_search([1e-2, 1e-4, 1e-6]),  # grid over learning rates
        "num_workers": 1,  # parallelism
        "env_config": {
            "corridor_length": 5,
        },
    }
    demo_spec = {
        "run": "PPO",
        "env": SimpleCorridor,
        "stop": {
            "timesteps_total": 10000,
        },
        "config": ppo_config,
    }
    run_experiments({"demo": demo_spec})
开发者ID:robertnishihara,项目名称:ray,代码行数:30,代码来源:custom_env.py


示例18: Cluster

# Build a cluster of num_nodes nodes. Only the first node is given an
# explicit redis port and redis shard count; every node gets 20 CPUs, no
# GPUs, and a custom resource named after its index.
cluster = Cluster()
for i in range(num_nodes):
    is_first_node = i == 0
    cluster.add_node(
        redis_port=6379 if is_first_node else None,
        num_redis_shards=num_redis_shards if is_first_node else None,
        num_cpus=20,
        num_gpus=0,
        resources={str(i): 2},
        object_store_memory=object_store_memory,
        redis_max_memory=redis_max_memory)
ray.init(redis_address=cluster.redis_address)

# Run the workload: "APEX" on Pong-v0 with the settings below.
apex_config = {
    "num_workers": 8,
    "num_gpus": 0,
    "buffer_size": 10000,
    "learning_starts": 0,
    "sample_batch_size": 1,
    "train_batch_size": 1,
    "min_iter_time_s": 10,
    "timesteps_per_iteration": 10,
}
run_experiments({
    "apex": {
        "run": "APEX",
        "env": "Pong-v0",
        "config": apex_config,
    }
})
开发者ID:robertnishihara,项目名称:ray,代码行数:30,代码来源:apex.py


示例19: open

        with open(checkpoint_path) as f:
            self.timestep = json.loads(f.read())["timestep"]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing")
    args, _ = parser.parse_known_args()
    ray.init()

    # HyperBand scheduler with `episode_reward_mean` as the reward attribute
    # and `training_iteration` as the time attribute (which, per the original
    # author's note, Tune fills in automatically).
    hyperband = HyperBandScheduler(
        time_attr="training_iteration",
        reward_attr="episode_reward_mean",
        max_t=100)

    # A smoke test stops every trial after a single iteration.
    iteration_limit = 1 if args.smoke_test else 99999
    sampled_config = {
        "width": sample_from(lambda spec: 10 + int(90 * random.random())),
        "height": sample_from(lambda spec: int(100 * random.random()))
    }
    exp = Experiment(
        name="hyperband_test",
        run=MyTrainableClass,
        num_samples=20,
        stop={"training_iteration": iteration_limit},
        config=sampled_config)

    run_experiments(exp, scheduler=hyperband)
开发者ID:jamescasbon,项目名称:ray,代码行数:30,代码来源:hyperband_example.py


示例20: print

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-iters", type=int, default=2000)
    args = parser.parse_args()

    ray.init()
    # Each callback is wrapped in tune.function before being placed in the
    # config.
    callbacks = {
        "on_episode_start": tune.function(on_episode_start),
        "on_episode_step": tune.function(on_episode_step),
        "on_episode_end": tune.function(on_episode_end),
        "on_sample_end": tune.function(on_sample_end),
        "on_train_result": tune.function(on_train_result),
    }
    test_spec = {
        "env": "CartPole-v0",
        "run": "PG",
        "stop": {
            "training_iteration": args.num_iters,
        },
        "config": {
            "callbacks": callbacks,
        },
    }
    trials = tune.run_experiments({"test": test_spec})

    # Verify that the custom metrics show up in the first trial's last result
    # (used by integration tests).
    custom_metrics = trials[0].last_result["custom_metrics"]
    print(custom_metrics)
    for metric_name in ("pole_angle_mean", "pole_angle_min", "pole_angle_max"):
        assert metric_name in custom_metrics
开发者ID:jamescasbon,项目名称:ray,代码行数:31,代码来源:custom_metrics_and_callbacks.py



注:本文中的ray.tune.run_experiments函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python rb.find_plugin_file函数代码示例 (发布时间:2022-05-26)
下一篇:
Python policy_evaluator.PolicyEvaluator类代码示例 (发布时间:2022-05-26)
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap