• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

R语言 Keras Training Flags

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

在需要经常进行调参的情况下,可以使用 Training Flags 来快速变换参数,比起直接修改模型参数来得快而且不易出错。

https://tensorflow.rstudio.com/tools/training_flags.html

使用 flags()

library(keras)

FLAGS <- flags(
  flag_integer("dense_units1", 128),
  flag_numeric("dropout1", 0.4),
  flag_integer("dense_units2", 128),
  flag_numeric("dropout2", 0.3),
  flag_integer("epochs", 30),
  flag_integer("batch_size", 128),
  flag_numeric("learning_rate", 0.001)
)
input <- layer_input(shape = c(784))
predictions <- input %>% 
  layer_dense(units = FLAGS$dense_units1, activation = \'relu\') %>%
  layer_dropout(rate = FLAGS$dropout1) %>%
  layer_dense(units = FLAGS$dense_units2, activation = \'relu\') %>%
  layer_dropout(rate = FLAGS$dropout2) %>%
  layer_dense(units = 10, activation = \'softmax\')

model <- keras_model(input, predictions) %>% compile(
  loss = \'categorical_crossentropy\',
  optimizer = optimizer_rmsprop(lr = FLAGS$learning_rate),
  metrics = c(\'accuracy\')
)

history <- model %>% fit(
  x_train, y_train,
  batch_size = FLAGS$batch_size,
  epochs = FLAGS$epochs,
  verbose = 1,
  validation_split = 0.2
)

flags()是 keras 库的函数,不是R语言本身的函数。

使用YAML文件

flags()可以搭配YAML文件使用。按照官方教程,以为是把参数定义在YAML文件里,然后使用flags(file="flags.yml")直接读入。但是发现这样行不通,flags(file="flags.yml")得到的是一个空list。后来发现可能得这样使用才是正确的:

FLAGS <- flags(file = "flags.yml",
  flag_integer("dense_units1", 128,  "Dense units in first layer"),
  flag_numeric("dropout1",     0.4,  "Dropout after first layer"),
  flag_integer("epochs",        30,  "Number of epochs to train for")
)

flags.yml 中的参数优先,会覆盖掉flags()里的定义,也就是说,如果 flags.yml 里面是这样定义的:

dense_units1: 256
dropout1: 0.4
epochs: 30

那么,dense_units1这个参数的值是 256,而不是 128。

下面这种用法不正确,

FLAGS <- flags(file = "flags.yml",
)

会得到一个空list。可以认为,flags.yml其实是用来覆盖或者说修改flags()里面已有的参数定义。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
R语言进行广州租房可视化发布时间:2022-07-18
下一篇:
基于SPSS Moderler和R语言的数据挖掘宽表处理发布时间:2022-07-18
热门推荐
热门话题
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap