We have three networks: an Encoder, a Generator, and a Discriminator.
The Encoder learns to map an input x onto the latent space z.
The Generator learns to generate x from the latent space z.
The Discriminator learns to discriminate whether an input image is real or generated.
Diagram of basic network input and output
Here, l_x_tilde and l_x are layers of high-level features learned by the discriminator.
We train the network to minimize the difference between the high-level features of x and x_tilde.
This is essentially an autoencoder that works on high-level features rather than pixels.
Adding this feature-level autoencoder to a GAN helps stabilize GAN training.
Training
Train Encoder on minimization of:
kullback_leibler_loss(z_x, gaussian)
mean_squared_error(l_x_tilde, l_x)
Train Generator on minimization of:
mean_squared_error(l_x_tilde, l_x)
-1*log(d_x_p)
Train Discriminator on minimization of:
-1*(log(d_x) + log(1 - d_x_p))
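For reference, the kullback_leibler_loss term above has a closed form when the encoder outputs a diagonal Gaussian. A minimal sketch in TensorFlow (using the z_x_mean and z_x_log_sigma_sq tensors defined later in this notebook):
# KL( N(z_x_mean, exp(z_x_log_sigma_sq)) || N(0, I) ), summed over the latent dimensions
kl = -0.5 * tf.reduce_sum(1 + z_x_log_sigma_sq
                          - tf.square(z_x_mean)
                          - tf.exp(z_x_log_sigma_sq), 1)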
# Import all of our packages
import os
import numpy as np
import prettytensor as pt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
from deconv import deconv2d
import IPython.display
import math
import tqdm # making loops prettier
import h5py # for reading our dataset
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed
%matplotlib inline
Parameters
dim1 = 64 # first dimension of input data
dim2 = 64 # second dimension of input data
dim3 = 3 # third dimension of input data (colors)
batch_size = 32 # size of batches to use (per GPU)
hidden_size = 2048 # size of hidden (z) layer to use
num_examples = 60000 # how many examples are in your training set
num_epochs = 10000 # number of epochs to run
### we can train our different networks with different learning rates if we want to
e_learning_rate = 1e-3
g_learning_rate = 1e-3
d_learning_rate = 1e-3
Which GPUs are we using?
Set gpus to a list of the GPUs you're using. The network will then split up the work between those GPUs.
gpus = [2] # Here I set CUDA to only see one GPU
os.environ["CUDA_VISIBLE_DEVICES"]=','.join([str(i) for i in gpus])
num_gpus = len(gpus) # number of GPUs to use
Reading the dataset from HDF5 format
Open `makedataset.ipynb` for instructions on how to build the dataset.
with h5py.File('datasets/faces_dataset_new.h5', 'r') as hf:
    faces = hf['images'].value
    headers = hf['headers'].value
    labels = hf['label_input'].value
# Normalize the dataset between 0 and 1
faces = (faces/255.)
# Just taking a look and making sure everything works
plt.imshow(np.reshape(faces[1], (64,64,3)), interpolation='nearest')
# grab the faces back out after we've flattened them
def create_image(im):
return np.reshape(im,(dim1,dim2,dim3))
# Lets just take a look at our channels
cm = plt.cm.hot
test_face = faces[0].reshape(dim1,dim2,dim3)
fig, ax = plt.subplots(nrows=1,ncols=4, figsize=(20,8))
ax[0].imshow(create_image(test_face), interpolation='nearest')
ax[1].imshow(create_image(test_face)[:,:,0], interpolation='nearest', cmap=cm)
ax[2].imshow(create_image(test_face)[:,:,1], interpolation='nearest', cmap=cm)
ax[3].imshow(create_image(test_face)[:,:,2], interpolation='nearest', cmap=cm)
A data iterator for batching (drawn up by Luke Metz)
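The iterator cell itself isn't reproduced here; a rough sketch of that kind of iterator (not Luke Metz's exact code; it assumes faces and labels are the numpy arrays loaded above) would be:
def data_iterator():
    """Shuffle the data each pass and yield (image_batch, label_batch) tuples of size batch_size * num_gpus."""
    while True:
        idxs = np.arange(0, len(faces))
        np.random.shuffle(idxs)
        for batch_idx in range(0, len(faces) - batch_size * num_gpus + 1, batch_size * num_gpus):
            cur_idxs = idxs[batch_idx:batch_idx + batch_size * num_gpus]
            images_batch = faces[cur_idxs]
            labels_batch = labels[cur_idxs]
            yield images_batch, labels_batch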
def encoder(X):
    '''Create the encoder network.
    Args:
        X: a batch of flattened images [batch_size, dim1*dim2*dim3]
    Returns:
        z_mean, z_log_sigma_sq: the mean and log variance of the latent code
        (the transformation is parameterized and can be learned)
    '''
lay_end = (pt.wrap(X).
reshape([batch_size, dim1, dim2, dim3]).
conv2d(5, 64, stride=2).
conv2d(5, 128, stride=2).
conv2d(5, 256, stride=2).
flatten())
z_mean = lay_end.fully_connected(hidden_size, activation_fn=None)
z_log_sigma_sq = lay_end.fully_connected(hidden_size, activation_fn=None)
return z_mean, z_log_sigma_sq
def generator(Z):
    '''Create the generator network.
    Decodes a batch of latent vectors into flattened images.
    Args:
        Z: a batch of latent vectors to decode
    Returns:
        A tensor that expresses the generator network
    '''
return (pt.wrap(Z).
            fully_connected(8*8*256).reshape([batch_size, 8, 8, 256]).  # reshape to [batch_size, 8, 8, 256]
deconv2d(5, 256, stride=2).
deconv2d(5, 128, stride=2).
deconv2d(5, 32, stride=2).
deconv2d(1, dim3, stride=1, activation_fn=tf.sigmoid).
flatten()
)
def discriminator(D_I):
    '''Create the discriminator network.
    Discriminates between images from the dataset and generated ones.
    Args:
        D_I: a batch of flattened images [batch_size, dim1*dim2*dim3] (real or generated)
    Returns:
        D: the probability that each input image is real
        lth_layer: the high-level feature layer used for the feature-wise reconstruction loss
    '''
    descrim_conv = (pt.wrap(D_I).  # this is what we're discriminating
reshape([batch_size, dim1, dim2, dim3]).
conv2d(5, 32, stride=1).
conv2d(5, 128, stride=2).
conv2d(5, 256, stride=2).
conv2d(5, 256, stride=2).
flatten()
)
    lth_layer = descrim_conv.fully_connected(1024, activation_fn=tf.nn.elu)  # this is the lth layer
    D = lth_layer.fully_connected(1, activation_fn=tf.nn.sigmoid)  # this is the actual discrimination
return D, lth_layer
Defining the forward pass through the network
This function is based upon the inference function from TensorFlow's CIFAR-10 multi-GPU tutorial.
Notice that I use with tf.variable_scope("enc"). This way, we can reuse these variables later with reuse=True, and we can also choose which variables each loss function trains based upon the scope label (e.g. enc).
def inference(x):
"""
    Run the models. It is called inference because it plays the same role as the inference function in TensorFlow's CIFAR-10 tutorial.
"""
z_p = tf.random_normal((batch_size, hidden_size), 0, 1) # normal dist for GAN
eps = tf.random_normal((batch_size, hidden_size), 0, 1) # normal dist for VAE
with pt.defaults_scope(activation_fn=tf.nn.elu,
batch_normalize=True,
learned_moments_update_rate=0.0003,
variance_epsilon=0.001,
scale_after_normalization=True):
with tf.variable_scope("enc"):
z_x_mean, z_x_log_sigma_sq = encoder(x) # get z from the input
with tf.variable_scope("gen"):
z_x = tf.add(z_x_mean,
tf.mul(tf.sqrt(tf.exp(z_x_log_sigma_sq)), eps)) # grab our actual z
x_tilde = generator(z_x)
with tf.variable_scope("dis"):
_, l_x_tilde = discriminator(x_tilde)
with tf.variable_scope("gen", reuse=True):
x_p = generator(z_p)
with tf.variable_scope("dis", reuse=True):
d_x, l_x = discriminator(x) # positive examples
with tf.variable_scope("dis", reuse=True):
d_x_p, _ = discriminator(x_p)
return z_x_mean, z_x_log_sigma_sq, z_x, x_tilde, l_x_tilde, x_p, d_x, l_x, d_x_p, z_p
Loss - define our various loss functions
SSE - we don't actually use this loss for training (it's really the MSE); it's just there to see how close x is to x_tilde.
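The cell defining these losses isn't reproduced here. Below is a sketch of a loss function matching the signature used in the tower loop further down; the clipping values and the per-pixel scaling are assumptions for illustration, not necessarily the exact choices used originally.
def loss(x, x_tilde, z_x_log_sigma_sq, z_x_mean, d_x, d_x_p, l_x, l_x_tilde, dim1, dim2, dim3):
    """Return SSE_loss, KL_loss, D_loss, G_loss, LL_loss for one tower (sketch)."""
    # pixel-wise reconstruction error, for monitoring only
    SSE_loss = tf.reduce_mean(tf.square(x - x_tilde))
    # KL divergence between the encoder's Gaussian and N(0, I), scaled per pixel
    KL_loss = tf.reduce_sum(-0.5 * tf.reduce_sum(1 + z_x_log_sigma_sq
                                                 - tf.square(z_x_mean)
                                                 - tf.exp(z_x_log_sigma_sq), 1)) / (dim1 * dim2 * dim3)
    # GAN losses, clipped to avoid log(0)
    D_loss = tf.reduce_mean(-1. * (tf.log(tf.clip_by_value(d_x, 1e-5, 1.0)) +
                                   tf.log(tf.clip_by_value(1.0 - d_x_p, 1e-5, 1.0))))
    G_loss = tf.reduce_mean(-1. * tf.log(tf.clip_by_value(d_x_p, 1e-5, 1.0)))
    # feature-wise ("learned similarity") reconstruction error on the discriminator's lth layer
    LL_loss = tf.reduce_sum(tf.square(l_x - l_x_tilde)) / (dim1 * dim2 * dim3)
    return SSE_loss, KL_loss, D_loss, G_loss, LL_loss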
Next we average gradients across the towers: basically, we take the list of gradients from each tower and average them together.
def average_gradients(tower_grads):
"""Calculate the average gradient for each shared variable across all towers.
Note that this function provides a synchronization point across all towers.
Args:
tower_grads: List of lists of (gradient, variable) tuples. The outer list
is over individual gradients. The inner list is over the gradient
calculation for each tower.
Returns:
List of pairs of (gradient, variable) where the gradient has been averaged
across all towers.
"""
average_grads = []
for grad_and_vars in zip(*tower_grads):
# Note that each grad_and_vars looks like the following:
# ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
grads = []
for g, _ in grad_and_vars:
# Add 0 dimension to the gradients to represent the tower.
expanded_g = tf.expand_dims(g, 0)
# Append on a 'tower' dimension which we will average over below.
grads.append(expanded_g)
# Average over the 'tower' dimension.
grad = tf.concat(0, grads)
grad = tf.reduce_mean(grad, 0)
# Keep in mind that the Variables are redundant because they are shared
# across towers. So .. we will just return the first tower's pointer to
# the Variable.
v = grad_and_vars[0][1]
grad_and_var = (grad, v)
average_grads.append(grad_and_var)
return average_grads
Plot network output
This is just my ugly function to regularly plot the output of the network; TensorBoard would probably be a better option for this.
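That plotting cell isn't shown here; a minimal stand-in that plots a row of reconstructions (the name plot_network_output and the recon argument, a numpy array of fetched x_tilde values, are hypothetical) could be:
def plot_network_output(recon, n=8):
    """Plot the first n reconstructed faces in a row."""
    fig, ax = plt.subplots(nrows=1, ncols=n, figsize=(2 * n, 2))
    for j in range(n):
        ax[j].imshow(create_image(recon[j]), interpolation='nearest')
        ax[j].axis('off')
    plt.show()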
# Make lists to save the losses to
# (you should probably just be using TensorBoard for any visualization instead)
G_loss_list = []
D_loss_list = []
SSE_loss_list = []
KL_loss_list = []
LL_loss_list = []
dxp_list = []
dx_list = []
Within your graph (a tf.Graph() created beforehand), define the global step (needed for multi-GPU training) and the optimizers for each of your networks.
with graph.as_default():
#with tf.Graph().as_default(), tf.device('/cpu:0'):
# Create a variable to count number of train calls
global_step = tf.get_variable(
'global_step', [],
initializer=tf.constant_initializer(0), trainable=False)
# different optimizers are needed for different learning rates (using the same learning rate seems to work fine though)
lr_D = tf.placeholder(tf.float32, shape=[])
lr_G = tf.placeholder(tf.float32, shape=[])
lr_E = tf.placeholder(tf.float32, shape=[])
opt_D = tf.train.AdamOptimizer(lr_D, epsilon=1.0)
opt_G = tf.train.AdamOptimizer(lr_G, epsilon=1.0)
opt_E = tf.train.AdamOptimizer(lr_E, epsilon=1.0)
Run all of the functions we defined above
tower_grads_e holds the list of gradients for the encoder, one entry per tower.
For each GPU, we grab the parameters corresponding to each network, calculate the gradients, and add them to the towers to be averaged.
with graph.as_default():
# These are the lists of gradients for each tower
tower_grads_e = []
tower_grads_g = []
tower_grads_d = []
all_input = tf.placeholder(tf.float32, [batch_size*num_gpus, dim1*dim2*dim3])
KL_param = tf.placeholder(tf.float32)
LL_param = tf.placeholder(tf.float32)
G_param = tf.placeholder(tf.float32)
# Define the network for each GPU
for i in xrange(num_gpus):
with tf.device('/gpu:%d' % i):
with tf.name_scope('Tower_%d' % (i)) as scope:
# grab this portion of the input
next_batch = all_input[i*batch_size:(i+1)*batch_size,:]
# Construct the model
z_x_mean, z_x_log_sigma_sq, z_x, x_tilde, l_x_tilde, x_p, d_x, l_x, d_x_p, z_p = inference(next_batch)
# Calculate the loss for this tower
SSE_loss, KL_loss, D_loss, G_loss, LL_loss = loss(next_batch, x_tilde, z_x_log_sigma_sq, z_x_mean, d_x, d_x_p, l_x, l_x_tilde, dim1, dim2, dim3)
# specify loss to parameters
params = tf.trainable_variables()
                E_params = [var for var in params if 'enc' in var.name]
                G_params = [var for var in params if 'gen' in var.name]
                D_params = [var for var in params if 'dis' in var.name]
                # Calculate the losses specific to the encoder, generator, and discriminator
L_e = tf.clip_by_value(KL_loss*KL_param + LL_loss, -100, 100)
L_g = tf.clip_by_value(LL_loss*LL_param+G_loss*G_param, -100, 100)
L_d = tf.clip_by_value(D_loss, -100, 100)
# Reuse variables for the next tower.
tf.get_variable_scope().reuse_variables()
                # Calculate the gradients for the batch of data on this tower.
grads_e = opt_E.compute_gradients(L_e, var_list = E_params)
grads_g = opt_G.compute_gradients(L_g, var_list = G_params)
grads_d = opt_D.compute_gradients(L_d, var_list = D_params)
# Keep track of the gradients across all towers.
tower_grads_e.append(grads_e)
tower_grads_g.append(grads_g)
tower_grads_d.append(grads_d)
Now let's average and apply those gradients.
with graph.as_default():
# Average the gradients
grads_e = average_gradients(tower_grads_e)
grads_g = average_gradients(tower_grads_g)
grads_d = average_gradients(tower_grads_d)
# apply the gradients with our optimizers
train_E = opt_E.apply_gradients(grads_e, global_step=global_step)
train_G = opt_G.apply_gradients(grads_g, global_step=global_step)
train_D = opt_D.apply_gradients(grads_d, global_step=global_step)
We squash each network's learning rate with a sigmoid of its recent performance: if the discriminator has been winning, its learning rate is lowered on the next batch, and likewise for the generator.
def sigmoid(x,shift,mult):
"""
Using this sigmoid to discourage one network overpowering the other
"""
return 1 / (1 + math.exp(-(x+shift)*mult))
fig, ax = plt.subplots(nrows=1,ncols=1, figsize=(18,4))
plt.plot(np.arange(0,1,.01), [sigmoid(i/100.,-.5,10) for i in range(100)])
ax.set_xlabel('Mean of Discriminator(Real) or Discriminator(Fake)')
ax.set_ylabel('Multiplier for learning rate')
plt.title('Squashing the Learning Rate to balance Discrim/Gen network performance')
total_batch = int(np.floor(num_examples / (batch_size * num_gpus)))  # how many batches are in an epoch
# We balance the generator's and discriminator's learning rates with a sigmoid function,
# encouraging the generator and discriminator to stay about equally matched
d_real = .5
d_fake = .5
epoch = 0  # epoch counter (assumed to start at 0; its original initialization isn't shown)
while epoch < num_epochs:
for i in tqdm.tqdm(range(total_batch)):
iter_ = data_iterator()
        # balance gen and discrim
e_current_lr = e_learning_rate*sigmoid(np.mean(d_real),-.5,15)
g_current_lr = g_learning_rate*sigmoid(np.mean(d_real),-.5,15)
d_current_lr = d_learning_rate*sigmoid(np.mean(d_fake),-.5,15)
next_batches, _ = iter_.next()