Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

in Technique[技术] by (71.8m points)

python - How to load and fine-tune a model on a subset of weights using TensorFlow

I am currently training a model using bls2017.py from this library, which is written in TensorFlow 1 (I'm sharing the code for context, but I don't think it's needed to answer this question).

This uses tf.train.MonitoredTrainingSession to train the compression model:

  with tf.train.MonitoredTrainingSession(
      hooks=hooks, checkpoint_dir=args.checkpoint_dir,
      save_checkpoint_secs=300, save_summaries_secs=60) as sess:
    while not sess.should_stop():
      sess.run(train_op)

I've now finished training the model. Next, I want to load the trained model and fine-tune only a small subset of its weights on a smaller dataset. In other words, I want to freeze most layers and train only the layers I choose.
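One way to express "train only the layers I choose" in TensorFlow 1 is to pass a `var_list` to the optimizer's `minimize()` call, so that gradients are applied only to those variables and everything else stays fixed. The sketch below is a minimal illustration under stated assumptions, not the library's own method: the helper names `select_variables` and `build_finetune_step` are hypothetical, the variable names in the demo are made up (print `tf.trainable_variables()` to see the real ones), and TensorFlow is assumed to be importable as `tensorflow.compat.v1`.

```python
import re
from collections import namedtuple


def select_variables(variables, pattern):
    """Return the variables whose .name matches the regex `pattern`."""
    regex = re.compile(pattern)
    return [v for v in variables if regex.search(v.name)]


def build_finetune_step(loss, pattern, learning_rate=1e-5, global_step=None):
    """Build a train op that updates only variables matching `pattern`.

    Hypothetical helper. TensorFlow is imported lazily so that the
    name-selection logic above can be tried without it installed.
    """
    import tensorflow.compat.v1 as tf  # assumption: TF1-style API

    subset = select_variables(tf.trainable_variables(), pattern)
    if not subset:
        raise ValueError("No trainable variable matches %r" % pattern)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    # var_list restricts the gradient update to `subset`; all other
    # variables are left untouched, i.e. frozen.
    train_op = optimizer.minimize(
        loss, global_step=global_step, var_list=subset)
    return train_op, subset


# Demo of the selection step with stand-in objects.
# These variable names are invented for illustration only.
Var = namedtuple("Var", "name")
demo_vars = [
    Var("analysis/layer_0/kernel:0"),
    Var("synthesis/layer_2/kernel:0"),
    Var("synthesis/layer_2/bias:0"),
]
chosen = select_variables(demo_vars, r"^synthesis/layer_2/")
print([v.name for v in chosen])
```

In the training code, this would stand in for the `main_optimizer.minimize(train_loss, global_step=step)` call, for example `build_finetune_step(train_loss, r"^synthesis/", global_step=step)`, again assuming the variables really are named that way.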

I have primarily worked with Keras and am not sure how to load the model and fine-tune a subset of its weights in TensorFlow, especially when it comes to handling TensorFlow graphs. Most of the examples I have come across (e.g. "Tensorflow: restoring a graph and model then running evaluation on a single image") deal with loading a trained model for evaluation rather than for retraining or fine-tuning. This is all I have so far:

latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
tf.train.Saver().restore(sess, save_path=latest)
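A possible complication with the two lines above is that a default `tf.train.Saver()` tries to restore every variable in the graph. If the fine-tuning graph contains variables that are not in the checkpoint (for example, slot variables of a newly created optimizer), the restore fails. The sketch below shows one common workaround under stated assumptions: restore only the variables that the checkpoint actually contains and leave the rest at their initial values. The helper names are hypothetical, the demo names are invented, `sess` is assumed to be a plain `tf.Session` (not a monitored session, whose graph is finalized), and TensorFlow is assumed to be importable as `tensorflow.compat.v1`.

```python
from collections import namedtuple


def split_by_checkpoint(variables, checkpoint_names):
    """Split variables into (found in checkpoint, missing from it)."""
    known = set(checkpoint_names)
    found, missing = [], []
    for v in variables:
        # Graph variable names end in ":0"; checkpoint keys do not.
        key = v.name.split(":")[0]
        (found if key in known else missing).append(v)
    return found, missing


def restore_for_finetuning(sess, checkpoint_dir):
    """Initialize all variables, then overwrite those saved in the checkpoint.

    Hypothetical helper; call it after the whole graph has been built.
    Returns the variables that were NOT restored.
    """
    import tensorflow.compat.v1 as tf  # assumption: TF1-style API

    latest = tf.train.latest_checkpoint(checkpoint_dir)
    if latest is None:
        raise ValueError("No checkpoint found in %r" % checkpoint_dir)
    # list_variables yields (name, shape) pairs stored in the checkpoint.
    names = [name for name, _ in tf.train.list_variables(latest)]
    found, missing = split_by_checkpoint(tf.global_variables(), names)
    if not found:
        raise ValueError("No graph variable is present in the checkpoint")
    sess.run(tf.global_variables_initializer())
    tf.train.Saver(var_list=found).restore(sess, latest)
    return missing


# Demo of the split with stand-in objects and invented names.
Var = namedtuple("Var", "name")
graph_vars = [Var("analysis/layer_0/kernel:0"), Var("finetune/new_slot:0")]
saved_names = ["analysis/layer_0/kernel", "global_step"]
found, missing = split_by_checkpoint(graph_vars, saved_names)
print([v.name for v in found], [v.name for v in missing])
```

Printing the returned `missing` list is a quick check that only the expected new variables were left at their initial values.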

I realize this is a big question but I am quite stuck and would really appreciate any help.


Also, here's a long code snippet of the training code in case it is of use:

  # Get training patch from dataset.
  x = train_dataset.make_one_shot_iterator().get_next()

  # Instantiate model.
  analysis_transform = AnalysisTransform(args.num_filters)
  entropy_bottleneck = tfc.EntropyBottleneck()
  synthesis_transform = SynthesisTransform(args.num_filters)

  # Build autoencoder.
  y = analysis_transform(x)
  y_tilde, likelihoods = entropy_bottleneck(y, training=True)
  x_tilde = synthesis_transform(y_tilde)

  # Total number of bits divided by number of pixels.
  train_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)

  # Mean squared error across pixels.
  train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
  # Multiply by 255^2 to correct for rescaling.
  train_mse *= 255 ** 2

  # The rate-distortion cost.
  train_loss = args.lmbda * train_mse + train_bpp

  # Minimize loss and auxiliary loss, and execute update op.
  step = tf.train.create_global_step()
  main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
  main_step = main_optimizer.minimize(train_loss, global_step=step)

  aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
  aux_step = aux_optimizer.minimize(entropy_bottleneck.losses[0])
  train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])

  hooks = [
      tf.train.StopAtStepHook(last_step=args.last_step),
      tf.train.NanTensorHook(train_loss),
  ]
  with tf.train.MonitoredTrainingSession(
      hooks=hooks, checkpoint_dir=args.checkpoint_dir,
      save_checkpoint_secs=300, save_summaries_secs=60) as sess:
    while not sess.should_stop():
      sess.run(train_op) 

1 Reply

0 votes
by (71.8m points)
Waiting for an expert to answer.
