I am currently training a model using bls2017.py from this library, which is written in TensorFlow 1 (I'm sharing the code to provide context, but I don't think it's necessary to answer this question).
This uses tf.train.MonitoredTrainingSession to train the compression model:
with tf.train.MonitoredTrainingSession(
        hooks=hooks, checkpoint_dir=args.checkpoint_dir,
        save_checkpoint_secs=300, save_summaries_secs=60) as sess:
    while not sess.should_stop():
        sess.run(train_op)
Now that training is finished, I want to load the trained model and fine-tune it on a smaller dataset, updating only a small subset of the weights. In other words, I want to freeze most layers and train the model only with respect to the layers I choose.
I have primarily worked with Keras and am not sure how to load the model and fine-tune a subset of weights in TensorFlow, especially when dealing with TensorFlow graphs. Most of the examples I have come across (e.g. Tensorflow: restoring a graph and model then running evaluation on a single image) deal with loading a trained model for evaluation rather than retraining/fine-tuning. This is all I have so far:
latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
tf.train.Saver().restore(sess, save_path=latest)
I realize this is a big question but I am quite stuck and would really appreciate any help.
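Here is roughly what I was thinking of trying, pieced together from other answers: filter tf.trainable_variables() by name prefix and pass the result as var_list to the optimizer's minimize call, so gradients are only computed for those variables. The prefix "synthesis" below is just a guess on my part (I don't actually know what scope names the transform layers end up under), and I'm not sure this is the right approach:

```python
# My rough sketch: keep only the variables I want to fine-tune.
# NOTE: "synthesis" is a guessed name prefix; the real variable names
# depend on how the transform classes name their scopes.

def select_vars(variables, prefixes):
    """Keep only the variables whose name starts with one of `prefixes`."""
    return [v for v in variables if any(v.name.startswith(p) for p in prefixes)]

# In the training script I would then (I think) change the minimize call to:
#
#   fine_tune_vars = select_vars(tf.trainable_variables(), ["synthesis"])
#   main_step = main_optimizer.minimize(train_loss, global_step=step,
#                                       var_list=fine_tune_vars)
#
# and point MonitoredTrainingSession at the old checkpoint_dir so that it
# restores all weights (frozen and trainable) before training resumes.
```

But I don't know whether this plays well with checkpoint restoring, or with the entropy bottleneck's auxiliary loss and update op.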
Also, here's a long code snippet of the training code in case it is of use:
# Get training patch from dataset.
x = train_dataset.make_one_shot_iterator().get_next()
# Instantiate model.
analysis_transform = AnalysisTransform(args.num_filters)
entropy_bottleneck = tfc.EntropyBottleneck()
synthesis_transform = SynthesisTransform(args.num_filters)
# Build autoencoder.
y = analysis_transform(x)
y_tilde, likelihoods = entropy_bottleneck(y, training=True)
x_tilde = synthesis_transform(y_tilde)
# Total number of bits divided by number of pixels.
train_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)
# Mean squared error across pixels.
train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
# Multiply by 255^2 to correct for rescaling.
train_mse *= 255 ** 2
# The rate-distortion cost.
train_loss = args.lmbda * train_mse + train_bpp
# Minimize loss and auxiliary loss, and execute update op.
step = tf.train.create_global_step()
main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
main_step = main_optimizer.minimize(train_loss, global_step=step)
aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
aux_step = aux_optimizer.minimize(entropy_bottleneck.losses[0])
train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])
hooks = [
    tf.train.StopAtStepHook(last_step=args.last_step),
    tf.train.NanTensorHook(train_loss),
]
with tf.train.MonitoredTrainingSession(
        hooks=hooks, checkpoint_dir=args.checkpoint_dir,
        save_checkpoint_secs=300, save_summaries_secs=60) as sess:
    while not sess.should_stop():
        sess.run(train_op)