• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python training.EpochLogger类代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中simplelearn.training.EpochLogger的典型用法代码示例。如果您正苦于以下问题：Python EpochLogger类的具体用法？Python EpochLogger怎么用？Python EpochLogger使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。



在下文中一共展示了EpochLogger类的8个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: main


#.........这里部分代码省略.........
    #

    """
    def make_output_basename(args):
        assert_equal(os.path.splitext(args.output_prefix)[1], "")
        if os.path.isdir(args.output_prefix) and \
           not args.output_prefix.endswith('/'):
            args.output_prefix += '/'

        output_dir, output_prefix = os.path.split(args.output_prefix)
        if output_prefix != "":
            output_prefix = output_prefix + "_"

        output_prefix = os.path.join(output_dir, output_prefix)

        return "{}lr-{}_mom-{}_nesterov-{}_bs-{}".format(
            output_prefix,
            args.learning_rate,
            args.initial_momentum,
            args.nesterov,
            args.batch_size)
    """

    assert_equal(os.path.splitext(args.output_prefix)[1], "")
    if os.path.isdir(args.output_prefix) and not args.output_prefix.endswith("/"):
        args.output_prefix += "/"

    output_dir, output_prefix = os.path.split(args.output_prefix)
    if output_prefix != "":
        output_prefix = output_prefix + "_"

    output_prefix = os.path.join(output_dir, output_prefix)

    epoch_logger = EpochLogger(output_prefix + "SGD_nesterov.h5")

    # misclassification_node = Misclassification(output_node, label_node)
    # mcr_logger = LogsToLists()
    # training_stopper = StopsOnStagnation(max_epochs=10,
    #                                      min_proportional_decrease=0.0)

    misclassification_node = Misclassification(output_node, label_lookup_node)

    validation_loss_monitor = MeanOverEpoch(loss_node, callbacks=[])
    epoch_logger.subscribe_to("validation mean loss", validation_loss_monitor)

    validation_misclassification_monitor = MeanOverEpoch(
        misclassification_node, callbacks=[print_mcr, StopsOnStagnation(max_epochs=20, min_proportional_decrease=0.0)]
    )

    epoch_logger.subscribe_to("validation misclassification", validation_misclassification_monitor)

    # batch callback (monitor)
    # training_loss_logger = LogsToLists()
    training_loss_monitor = MeanOverEpoch(loss_node, callbacks=[print_loss])
    epoch_logger.subscribe_to("training mean loss", training_loss_monitor)

    training_misclassification_monitor = MeanOverEpoch(misclassification_node, callbacks=[])
    epoch_logger.subscribe_to("training misclassification %", training_misclassification_monitor)

    # epoch callbacks
    # validation_loss_logger = LogsToLists()

    def make_output_filename(args, best=False):
        '''
        Returns the pickle filename for the given command-line params.

        The name is make_output_basename(args), followed by "_best" when
        `best` is truthy, followed by ".pkl".
        '''
        # NOTE(review): in this excerpt make_output_basename only appears
        # inside a disabled string literal -- confirm it is defined in scope.
        stem = make_output_basename(args)
        tag = "_best" if best else ""
        return "{}{}.pkl".format(stem, tag)
开发者ID:paulfun92,项目名称:project_code,代码行数:66,代码来源:SGD_nesterov.py


示例2: main


#.........这里部分代码省略.........
                                                    args.learning_rate,
                                                    args.initial_momentum,
                                                    args.nesterov)
            parameter_updaters.append(parameter_updater)

            momentum_updaters.append(LinearlyInterpolatesOverEpochs(
                parameter_updater.momentum,
                args.final_momentum,
                args.epochs_to_momentum_saturation))

    #
    # Makes batch and epoch callbacks
    #

    def make_output_basename(args):
        '''
        Builds the base path (no extension) shared by this run's output files.

        The result is args.output_prefix followed by the learning rate,
        initial momentum, nesterov flag and batch size. A non-empty file
        prefix is separated from those values by an underscore.

        Side effect: when args.output_prefix names an existing directory and
        lacks a trailing '/', the '/' is appended to args.output_prefix
        in place.

        Fails (via assert_equal) if args.output_prefix has a file extension.
        '''
        extension = os.path.splitext(args.output_prefix)[1]
        assert_equal(extension, "")

        # Give directories a trailing slash so that os.path.split() below
        # returns an empty file-name component for them.
        if os.path.isdir(args.output_prefix):
            if not args.output_prefix.endswith('/'):
                args.output_prefix += '/'

        directory, stem = os.path.split(args.output_prefix)
        tagged_stem = stem + "_" if stem != "" else stem
        path_head = os.path.join(directory, tagged_stem)

        settings = (args.learning_rate,
                    args.initial_momentum,
                    args.nesterov,
                    args.batch_size)
        return "{}lr-{}_mom-{}_nesterov-{}_bs-{}".format(path_head, *settings)

    epoch_logger = EpochLogger(make_output_basename(args) + "_log.h5")

    # misclassification_node = Misclassification(output_node, label_node)
    # mcr_logger = LogsToLists()
    # training_stopper = StopsOnStagnation(max_epochs=10,
    #                                      min_proportional_decrease=0.0)
    misclassification_node = Misclassification(output_node, label_node)

    validation_loss_monitor = MeanOverEpoch(loss_node, callbacks=[])
    epoch_logger.subscribe_to('validation mean loss', validation_loss_monitor)

    validation_misclassification_monitor = MeanOverEpoch(
        misclassification_node,
        callbacks=[print_mcr,
                   StopsOnStagnation(max_epochs=10,
                                     min_proportional_decrease=0.0)])

    epoch_logger.subscribe_to('validation misclassification',
                              validation_misclassification_monitor)

    # batch callback (monitor)
    # training_loss_logger = LogsToLists()
    training_loss_monitor = MeanOverEpoch(loss_node, callbacks=[print_loss])
    epoch_logger.subscribe_to('training mean loss', training_loss_monitor)

    training_misclassification_monitor = MeanOverEpoch(misclassification_node,
                                                       callbacks=[])
    epoch_logger.subscribe_to('training misclassification %',
                              training_misclassification_monitor)

    # epoch callbacks
    # validation_loss_logger = LogsToLists()
开发者ID:paulfun92,项目名称:simplelearn,代码行数:66,代码来源:mnist_fully_connected.py


示例3: main


#.........这里部分代码省略.........
                     parameter_updaters,
                     momentum_updaters)

    #
    # Makes batch and epoch callbacks
    #
    def make_output_filename(args, best=False):
        '''
        Constructs a .pkl filename that reflects the command-line params.

        args.output_prefix may be an existing directory (files are then named
        purely by the hyperparameters), or a path whose last component is a
        file-name prefix. A bare prefix with no directory part (e.g. "run1")
        is resolved against the current directory.

        Parameters
        ----------
        args: object with output_prefix, learning_rate, initial_momentum,
              nesterov and batch_size attributes.
        best: if truthy, "_best" is inserted before the ".pkl" extension.

        Returns
        -------
        str: "<prefix>lr-<lr>_mom-<mom>_nesterov-<flag>_bs-<bs>[_best].pkl"

        Fails (via assert_equal / assert_true) if args.output_prefix has a
        file extension, or if its directory part does not exist.
        '''
        assert_equal(os.path.splitext(args.output_prefix)[1], "")

        if os.path.isdir(args.output_prefix):
            output_dir, output_prefix = args.output_prefix, ""
        else:
            output_dir, output_prefix = os.path.split(args.output_prefix)
            # os.path.split() returns "" as the directory of a bare prefix,
            # and os.path.isdir("") is False. Check the current directory
            # in that case instead of rejecting a valid prefix.
            assert_true(os.path.isdir(output_dir or os.curdir))

        if output_prefix != "":
            output_prefix = output_prefix + "_"

        # Joining with an empty output_prefix leaves a trailing separator on
        # the directory, so the hyperparameter string can be appended as-is.
        output_prefix = os.path.join(output_dir, output_prefix)

        return ("%slr-%g_mom-%g_nesterov-%s_bs-%d%s.pkl" %
                (output_prefix,
                 args.learning_rate,
                 args.initial_momentum,
                 args.nesterov,
                 args.batch_size,
                 "_best" if best else ""))


    # Set up the loggers
    epoch_logger = EpochLogger(make_output_filename(args) + "_log.h5")
    misclassification_node = Misclassification(output_node, label_lookup_node)

    validation_loss_monitor = MeanOverEpoch(loss_node, callbacks=[])
    epoch_logger.subscribe_to('validation mean loss', validation_loss_monitor)

    training_stopper = StopsOnStagnation(max_epochs=201,
                                             min_proportional_decrease=0.0)
    validation_misclassification_monitor = MeanOverEpoch(misclassification_node,
                                             callbacks=[print_misclassification_rate,
                                                        training_stopper])

    epoch_logger.subscribe_to('validation misclassification',
                                validation_misclassification_monitor)

    # batch callback (monitor)
    #training_loss_logger = LogsToLists()
    training_loss_monitor = MeanOverEpoch(loss_node,
                                          callbacks=[print_loss])
    epoch_logger.subscribe_to("training loss", training_loss_monitor)

    training_misclassification_monitor = MeanOverEpoch(misclassification_node,
                                                       callbacks=[])
    epoch_logger.subscribe_to('training misclassification %',
                              training_misclassification_monitor)

    epoch_timer = EpochTimer2()
    epoch_logger.subscribe_to('epoch duration', epoch_timer)
#    epoch_logger.subscribe_to('epoch time',
 #                             epoch_timer)
    #################

开发者ID:paulfun92,项目名称:project_code,代码行数:65,代码来源:cifar10_conv3.py


示例4: main


#.........这里部分代码省略.........
                output_dir, output_prefix = args.output_prefix, ""
            else:
                output_dir, output_prefix = os.path.split(args.output_prefix)
                assert_true(os.path.isdir(output_dir))

            if output_prefix != "":
                output_prefix = output_prefix + "_"

            output_prefix = os.path.join(output_dir, output_prefix)

            return ("%slr-%g_mom-%g_nesterov-%s_bs-%d%s.pkl" %
                    (output_prefix,
                     args.learning_rate,
                     args.initial_momentum,
                     args.nesterov,
                     args.batch_size,
                     "_best" if best else ""))
    '''


    # Set up the loggers

    assert_equal(os.path.splitext(args.output_prefix)[1], "")
    if os.path.isdir(args.output_prefix) and \
       not args.output_prefix.endswith('/'):
        args.output_prefix += '/'

    output_dir, output_prefix = os.path.split(args.output_prefix)
    if output_prefix != "":
        output_prefix = output_prefix + "_"

    output_prefix = os.path.join(output_dir, output_prefix)

    epoch_logger = EpochLogger(output_prefix + "S2GD_plus.h5")


    misclassification_node = Misclassification(output_node, label_lookup_node)

    validation_loss_monitor = MeanOverEpoch(loss_node, callbacks=[])
    epoch_logger.subscribe_to('validation mean loss', validation_loss_monitor)

    training_stopper = StopsOnStagnation(max_epochs=20,
                                             min_proportional_decrease=0.0)
    validation_misclassification_monitor = MeanOverEpoch(misclassification_node,
                                             callbacks=[print_misclassification_rate,
                                                        training_stopper])

    epoch_logger.subscribe_to('validation misclassification',
                                validation_misclassification_monitor)

    # batch callback (monitor)
    #training_loss_logger = LogsToLists()
    training_loss_monitor = MeanOverEpoch(loss_node,
                                          callbacks=[print_loss])
    epoch_logger.subscribe_to("training loss", training_loss_monitor)

    training_misclassification_monitor = MeanOverEpoch(misclassification_node,
                                                       callbacks=[])
    epoch_logger.subscribe_to('training misclassification %',
                              training_misclassification_monitor)

    epoch_timer = EpochTimer2()
    epoch_logger.subscribe_to('epoch duration', epoch_timer)
#    epoch_logger.subscribe_to('epoch time',
 #                             epoch_timer)
    #################
开发者ID:paulfun92,项目名称:project_code,代码行数:67,代码来源:S2GD_plus.py


示例5: main


#.........这里部分代码省略.........
    print(grads)
    print(grads.shape)

    #
    # Makes batch and epoch callbacks
    #
    def make_output_filename(args, best=False):
        '''
        Constructs a .pkl filename that reflects the command-line params.

        args.output_prefix may be an existing directory (files are then named
        purely by the hyperparameters), or a path whose last component is a
        file-name prefix. A bare prefix with no directory part (e.g. "run1")
        is resolved against the current directory.

        Parameters
        ----------
        args: object with output_prefix, learning_rate, initial_momentum,
              no_nesterov and batch_size attributes.
        best: if truthy, "_best" is inserted before the ".pkl" extension.

        Returns
        -------
        str: "<prefix>lr-<lr>_mom-<mom>_nesterov-<flag>_bs-<bs>[_best].pkl",
             where <flag> is the negation of args.no_nesterov.

        Fails (via assert_equal / assert_true) if args.output_prefix has a
        file extension, or if its directory part does not exist.
        '''
        assert_equal(os.path.splitext(args.output_prefix)[1], "")

        if os.path.isdir(args.output_prefix):
            output_dir, output_prefix = args.output_prefix, ""
        else:
            output_dir, output_prefix = os.path.split(args.output_prefix)
            # os.path.split() returns "" as the directory of a bare prefix,
            # and os.path.isdir("") is False. Check the current directory
            # in that case instead of rejecting a valid prefix.
            assert_true(os.path.isdir(output_dir or os.curdir))

        if output_prefix != "":
            output_prefix = output_prefix + "_"

        # Joining with an empty output_prefix leaves a trailing separator on
        # the directory, so the hyperparameter string can be appended as-is.
        output_prefix = os.path.join(output_dir, output_prefix)

        return ("%slr-%g_mom-%g_nesterov-%s_bs-%d%s.pkl" %
                (output_prefix,
                 args.learning_rate,
                 args.initial_momentum,
                 not args.no_nesterov,
                 args.batch_size,
                 "_best" if best else ""))


    # Set up the loggers
    epoch_logger = EpochLogger(make_output_filename(args) + "_log.h5")
    misclassification_node = Misclassification(output_node, label_lookup_node)

    validation_loss_monitor = MeanOverEpoch(loss_node, callbacks=[])
    epoch_logger.subscribe_to('validation mean loss', validation_loss_monitor)

    training_stopper = StopsOnStagnation(max_epochs=100,
                                             min_proportional_decrease=0.0)
    validation_misclassification_monitor = MeanOverEpoch(misclassification_node,
                                             callbacks=[print_misclassification_rate,
                                                        training_stopper])

    epoch_logger.subscribe_to('validation misclassification',
                                validation_misclassification_monitor)

    # batch callback (monitor)
    #training_loss_logger = LogsToLists()
    training_loss_monitor = MeanOverEpoch(loss_node,
                                          callbacks=[print_loss])
    epoch_logger.subscribe_to("training loss", training_loss_monitor)

    training_misclassification_monitor = MeanOverEpoch(misclassification_node,
                                                       callbacks=[])
    epoch_logger.subscribe_to('training misclassification %',
                              training_misclassification_monitor)

    epoch_timer = EpochTimer()
#    epoch_logger.subscribe_to('epoch time',
 #                             epoch_timer)
    #################


    model = SerializableModel([input_indices_symbolic], [output_node])
开发者ID:paulfun92,项目名称:project_code,代码行数:67,代码来源:LBFGS_mnist_conv3.py


示例6: main


#.........这里部分代码省略.........
                                                    args.learning_rate,
                                                    args.initial_momentum,
                                                    args.nesterov)
            parameter_updaters.append(parameter_updater)

            momentum_updaters.append(LinearlyInterpolatesOverEpochs(
                parameter_updater.momentum,
                args.final_momentum,
                args.epochs_to_momentum_saturation))

    #
    # Makes batch and epoch callbacks
    #

    def make_output_basename(args):
        '''
        Builds the base path (no extension) shared by this run's output files.

        The result is args.output_prefix followed by the learning rate,
        initial momentum, nesterov flag and batch size. A non-empty file
        prefix is separated from those values by an underscore.

        Side effect: when args.output_prefix names an existing directory and
        lacks a trailing '/', the '/' is appended to args.output_prefix
        in place.

        Fails (via assert_equal) if args.output_prefix has a file extension.
        '''
        extension = os.path.splitext(args.output_prefix)[1]
        assert_equal(extension, "")

        # Give directories a trailing slash so that os.path.split() below
        # returns an empty file-name component for them.
        if os.path.isdir(args.output_prefix):
            if not args.output_prefix.endswith('/'):
                args.output_prefix += '/'

        directory, stem = os.path.split(args.output_prefix)
        tagged_stem = stem + "_" if stem != "" else stem
        path_head = os.path.join(directory, tagged_stem)

        settings = (args.learning_rate,
                    args.initial_momentum,
                    args.nesterov,
                    args.batch_size)
        return "{}lr-{}_mom-{}_nesterov-{}_bs-{}".format(path_head, *settings)

    epoch_logger = EpochLogger(make_output_basename(args) + "_log.h5")

    # misclassification_node = Misclassification(output_node, label_node)
    # mcr_logger = LogsToLists()
    # training_stopper = StopsOnStagnation(max_epochs=10,
    #                                      min_proportional_decrease=0.0)
    misclassification_node = Misclassification(output_node, label_node)

    validation_loss_monitor = MeanOverEpoch(loss_node, callbacks=[])
    epoch_logger.subscribe_to('validation mean loss', validation_loss_monitor)

    validation_misclassification_monitor = MeanOverEpoch(
        misclassification_node,
        callbacks=[print_mcr,
                   StopsOnStagnation(max_epochs=10,
                                     min_proportional_decrease=0.0)])

    epoch_logger.subscribe_to('validation misclassification',
                              validation_misclassification_monitor)

    # batch callback (monitor)
    # training_loss_logger = LogsToLists()
    training_loss_monitor = MeanOverEpoch(loss_node, callbacks=[print_loss])
    epoch_logger.subscribe_to('training mean loss', training_loss_monitor)

    training_misclassification_monitor = MeanOverEpoch(misclassification_node,
                                                       callbacks=[])
    epoch_logger.subscribe_to('training misclassification %',
                              training_misclassification_monitor)

    # epoch callbacks
    # validation_loss_logger = LogsToLists()
开发者ID:paulfun92,项目名称:project_code,代码行数:66,代码来源:SGD_mnist_fully_connected.py


示例7: main


#.........这里部分代码省略.........
                     parameter_updaters,
                     momentum_updaters)

    #
    # Makes batch and epoch callbacks
    #
    def make_output_filename(args, best=False):
        '''
        Constructs a .pkl filename that reflects the command-line params.

        args.output_prefix may be an existing directory (files are then named
        purely by the hyperparameters), or a path whose last component is a
        file-name prefix. A bare prefix with no directory part (e.g. "run1")
        is resolved against the current directory.

        Parameters
        ----------
        args: object with output_prefix, learning_rate, initial_momentum,
              no_nesterov and batch_size attributes.
        best: if truthy, "_best" is inserted before the ".pkl" extension.

        Returns
        -------
        str: "<prefix>lr-<lr>_mom-<mom>_nesterov-<flag>_bs-<bs>[_best].pkl",
             where <flag> is the negation of args.no_nesterov.

        Fails (via assert_equal / assert_true) if args.output_prefix has a
        file extension, or if its directory part does not exist.
        '''
        assert_equal(os.path.splitext(args.output_prefix)[1], "")

        if os.path.isdir(args.output_prefix):
            output_dir, output_prefix = args.output_prefix, ""
        else:
            output_dir, output_prefix = os.path.split(args.output_prefix)
            # os.path.split() returns "" as the directory of a bare prefix,
            # and os.path.isdir("") is False. Check the current directory
            # in that case instead of rejecting a valid prefix.
            assert_true(os.path.isdir(output_dir or os.curdir))

        if output_prefix != "":
            output_prefix = output_prefix + "_"

        # Joining with an empty output_prefix leaves a trailing separator on
        # the directory, so the hyperparameter string can be appended as-is.
        output_prefix = os.path.join(output_dir, output_prefix)

        return ("%slr-%g_mom-%g_nesterov-%s_bs-%d%s.pkl" %
                (output_prefix,
                 args.learning_rate,
                 args.initial_momentum,
                 not args.no_nesterov,
                 args.batch_size,
                 "_best" if best else ""))


    # Set up the loggers
    epoch_logger = EpochLogger(make_output_filename(args) + "_log.h5")
    misclassification_node = Misclassification(output_node, label_node)

    validation_loss_monitor = MeanOverEpoch(loss_node, callbacks=[])
    epoch_logger.subscribe_to('validation mean loss', validation_loss_monitor)

    training_stopper = StopsOnStagnation(max_epochs=100,
                                             min_proportional_decrease=0.0)
    validation_misclassification_monitor = MeanOverEpoch(misclassification_node,
                                             callbacks=[print_misclassification_rate,
                                                        training_stopper])

    epoch_logger.subscribe_to('validation misclassification',
                                validation_misclassification_monitor)

    # batch callback (monitor)
    #training_loss_logger = LogsToLists()
    training_loss_monitor = MeanOverEpoch(loss_node,
                                          callbacks=[print_loss])
    epoch_logger.subscribe_to("training loss", training_loss_monitor)

    training_misclassification_monitor = MeanOverEpoch(misclassification_node,
                                                       callbacks=[])
    epoch_logger.subscribe_to('training misclassification %',
                              training_misclassification_monitor)

    epoch_timer = EpochTimer()
#    epoch_logger.subscribe_to('epoch time',
 #                             epoch_timer)
    #################


    model = SerializableModel([image_uint8_node], [output_node])
开发者ID:paulfun92,项目名称:project_code,代码行数:67,代码来源:GD_mnist_conv.py


示例8: main


#.........这里部分代码省略.........
                                        sparse_init_counts,
                                        args.dropout_include_rates,
                                        rng,
                                        theano_rng)

    loss_node = CrossEntropy(output_node, label_lookup_node)
    loss_sum = loss_node.output_symbol.mean()
    max_epochs = 10000
    gradient = theano.gradient.grad(loss_sum, params_flat)

    #
    # Makes batch and epoch callbacks
    #

    def make_output_basename(args):
        '''
        Builds the base path (no extension) shared by this run's output files.

        The result is args.output_prefix followed by the learning rate,
        initial momentum, nesterov flag and batch size. A non-empty file
        prefix is separated from those values by an underscore.

        Side effect: when args.output_prefix names an existing directory and
        lacks a trailing '/', the '/' is appended to args.output_prefix
        in place.

        Fails (via assert_equal) if args.output_prefix has a file extension.
        '''
        extension = os.path.splitext(args.output_prefix)[1]
        assert_equal(extension, "")

        # Give directories a trailing slash so that os.path.split() below
        # returns an empty file-name component for them.
        if os.path.isdir(args.output_prefix):
            if not args.output_prefix.endswith('/'):
                args.output_prefix += '/'

        directory, stem = os.path.split(args.output_prefix)
        tagged_stem = stem + "_" if stem != "" else stem
        path_head = os.path.join(directory, tagged_stem)

        settings = (args.learning_rate,
                    args.initial_momentum,
                    args.nesterov,
                    args.batch_size)
        return "{}lr-{}_mom-{}_nesterov-{}_bs-{}".format(path_head, *settings)

    epoch_logger = EpochLogger(make_output_basename(args) + "_log.h5")

    # misclassification_node = Misclassification(output_node, label_node)
    # mcr_logger = LogsToLists()
    # training_stopper = StopsOnStagnation(max_epochs=10,
    #                                      min_proportional_decrease=0.0)

    misclassification_node = Misclassification(output_node, label_lookup_node)

    validation_loss_monitor = MeanOverEpoch(loss_node, callbacks=[])
    epoch_logger.subscribe_to('validation mean loss', validation_loss_monitor)

    validation_misclassification_monitor = MeanOverEpoch(
        misclassification_node,
        callbacks=[print_mcr,
                   StopsOnStagnation(max_epochs=100,
                                     min_proportional_decrease=0.0)])

    epoch_logger.subscribe_to('validation misclassification',
                              validation_misclassification_monitor)

    # batch callback (monitor)
    # training_loss_logger = LogsToLists()
    training_loss_monitor = MeanOverEpoch(loss_node, callbacks=[print_loss])
    epoch_logger.subscribe_to('training mean loss', training_loss_monitor)

    training_misclassification_monitor = MeanOverEpoch(misclassification_node,
                                                       callbacks=[])
    epoch_logger.subscribe_to('training misclassification %',
                              training_misclassification_monitor)

    # epoch callbacks
    # validation_loss_logger = LogsToLists()
开发者ID:paulfun92,项目名称:project_code,代码行数:67,代码来源:LBFGS_fully_connected_CIFAR10.py



注:本文中的simplelearn.training.EpochLogger类示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python utils.safe_izip函数代码示例发布时间:2022-05-27
下一篇:
Python mnist.load_mnist函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap