This article collects typical usage examples of the Python function neon.util.persist.save_obj. If you are wondering how save_obj is used in practice, how to call it, or what real-world code using it looks like, the hand-picked examples below should help.
The section below presents 20 code examples of the save_obj function, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
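Before the examples, here is a minimal round-trip sketch (not taken from the examples below; the file name and dictionary contents are illustrative only). save_obj pickles a Python object to the given path, and its counterpart load_obj, which several examples below use, reads it back:

# Minimal sketch (assumed usage; illustrative file name and payload).
from neon.util.persist import save_obj, load_obj

params = {'epoch_index': 3, 'weights': [0.1, 0.2, 0.3]}
save_obj(params, 'params.pkl')       # serialize the object to a pickle file
restored = load_obj('params.pkl')    # read the same object back
assert restored['epoch_index'] == 3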
Example 1: save_history
def save_history(self, epoch, model):
    # if history > 1, this function will save the last N checkpoints
    # where N is equal to self.history. The files will have the form
    # of save_path with the epoch added to the filename before the ext
    if len(self.checkpoint_files) > self.history:
        # remove oldest checkpoint file when max count have been saved
        fn = self.checkpoint_files.popleft()
        try:
            os.remove(fn)
            logger.info("removed old checkpoint %s" % fn)
        except OSError:
            logger.warn("Could not delete old checkpoint file %s" % fn)

    path_split = os.path.splitext(self.save_path)
    save_path = "%s_%d%s" % (path_split[0], epoch, path_split[1])
    # add the current file to the deque
    self.checkpoint_files.append(save_path)
    save_obj(model.serialize(keep_states=True), save_path)

    # maintain a symlink pointing to the latest model params
    try:
        if os.path.islink(self.save_path):
            os.remove(self.save_path)
        os.symlink(os.path.split(save_path)[-1], self.save_path)
    except OSError:
        logger.warn("Could not create latest model symlink %s -> %s" % (self.save_path, save_path))
Developer: yapjiaqing, Project: neon, Lines: 27, Source file: callbacks.py
Example 2: on_epoch_end
def on_epoch_end(self, callback_data, model, epoch):
    _eil = self._get_cached_epoch_loss(callback_data, model, epoch, "loss")
    if _eil:
        if _eil["cost"] < self.best_cost or self.best_cost is None:
            # TODO: switch this to a general seralization op
            save_obj(model.serialize(keep_states=True), self.best_path)
            self.best_cost = _eil["cost"]
Developer: yapjiaqing, Project: neon, Lines: 7, Source file: callbacks.py
Example 3: serialize
def serialize(self, fn=None, keep_states=True):
    """
    Creates a dictionary storing the layer parameters and epochs complete.

    Arguments:
        fn (str): file to save pkl formatted model dictionary
        keep_states (bool): Whether to save optimizer states.

    Returns:
        dict: Model data including layer parameters and epochs complete.
    """
    # get the model dict with the weights
    pdict = self.get_description(get_weights=True, keep_states=keep_states)
    pdict['epoch_index'] = self.epoch_index + 1
    if self.initialized:
        if not hasattr(self.layers, 'decoder'):
            pdict['train_input_shape'] = self.layers.in_shape
        else:
            # serialize shapes both for encoder and decoder
            pdict['train_input_shape'] = (self.layers.encoder.in_shape +
                                          self.layers.decoder.in_shape)
    if fn is not None:
        save_obj(pdict, fn)
        return
    return pdict
Developer: rlugojr, Project: neon, Lines: 26, Source file: model.py
Example 4: save_params
def save_params(self, param_path, keep_states=True):
    """
    Serializes and saves model parameters to the path specified.

    Arguments:
        param_path (str): File to write serialized parameter dict to.
        keep_states (bool): Whether to save optimizer states too.
                            Defaults to True.
    """
    save_obj(self.serialize(keep_states), param_path)
Developer: bin2000, Project: neon, Lines: 10, Source file: model.py
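A hypothetical usage note for Example 4 (the model name and file path below are made up for illustration): save_params is called on a trained Model instance, and the resulting pickle can be restored into a model built from the same layer definition via load_weights, as Examples 17 and 20 do:

# Illustrative sketch only; `mlp`, `layers`, and the path are assumed names.
mlp.save_params('mnist_model.pkl', keep_states=True)   # writes serialize() output via save_obj
mlp_restored = Model(layers=layers)                     # same layer definition as the original
mlp_restored.load_weights('mnist_model.pkl')            # load_weights appears in Examples 17 and 20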
Example 5: on_epoch_end
def on_epoch_end(self, epoch):
    if 'cost/validation' in self.callback_data:
        val_freq = self.callback_data['cost/validation'].attrs['epoch_freq']
        if (epoch + 1) % val_freq == 0:
            validation_cost = self.callback_data['cost/validation'][epoch/val_freq]
            if validation_cost < self.best_cost or self.best_cost is None:
                save_obj(self.model.serialize(keep_states=True), self.best_path)
                self.best_cost = validation_cost
Developer: rupertsmall, Project: neon, Lines: 10, Source file: callbacks.py
Example 6: save_meta
def save_meta(self):
    save_obj({'ntrain': self.ntrain,
              'nval': self.nval,
              'train_start': self.train_start,
              'val_start': self.val_start,
              'macro_size': self.macro_size,
              'batch_prefix': self.batch_prefix,
              'global_mean': self.global_mean,
              'label_dict': self.label_dict,
              'label_names': self.label_names,
              'val_nrec': self.val_nrec,
              'train_nrec': self.train_nrec,
              'img_size': self.target_size,
              'nclass': self.nclass}, self.meta_file)
Developer: GerritKlaschke, Project: neon, Lines: 14, Source file: batch_writer.py
Example 7: on_sigint_catch
def on_sigint_catch(self, epoch, minibatch):
    """
    Callback to handle SIGINT events

    Arguments:
        epoch (int): index of current epoch
        minibatch (int): index of minibatch that is ending
    """
    # restore the orignal handler
    signal.signal(signal.SIGINT, signal.SIG_DFL)

    # save the model
    if self.save_path is not None:
        save_obj(self.model().serialize(keep_states=True), self.save_path)
        raise KeyboardInterrupt("Checkpoint file saved to {0}".format(self.save_path))
    else:
        raise KeyboardInterrupt
Developer: yapjiaqing, Project: neon, Lines: 17, Source file: callbacks.py
Example 8: serialize
def serialize(self, fn=None, keep_states=True):
    """
    Creates a dictionary storing the layer parameters and epochs complete.

    Arguments:
        fn (str): file to save pkl formatted model dictionary
        keep_states (bool): Whether to save optimizer states.

    Returns:
        dict: Model data including layer parameters and epochs complete.
    """
    # get the model dict with the weights
    pdict = self.get_description(get_weights=True, keep_states=keep_states)
    pdict['epoch_index'] = self.epoch_index + 1
    if fn is not None:
        save_obj(pdict, fn)
        return
    return pdict
Developer: maony, Project: neon, Lines: 19, Source file: model.py
Example 9: save_meta
def save_meta(self):
    save_obj(
        {
            "ntrain": self.ntrain,
            "nval": self.nval,
            "train_start": self.train_start,
            "val_start": self.val_start,
            "macro_size": self.macro_size,
            "batch_prefix": self.batch_prefix,
            "global_mean": self.global_mean,
            "label_dict": self.label_dict,
            "label_names": self.label_names,
            "val_nrec": self.val_nrec,
            "train_nrec": self.train_nrec,
            "img_size": self.target_size,
            "nclass": self.nclass,
        },
        self.meta_file,
    )
Developer: hgl888, Project: neon, Lines: 19, Source file: batch_writer.py
Example 10: save_history
def save_history(self, epoch):
    # if history > 1, this function will save the last N checkpoints
    # where N is equal to self.history. The files will have the form
    # of save_path with the epoch added to the filename before the ext
    if len(self.checkpoint_files) > self.history:
        # remove oldest checkpoint file when max count have been saved
        fn = self.checkpoint_files.popleft()
        try:
            os.remove(fn)
            logger.info('removed old checkpoint %s' % fn)
        except OSError:
            logger.warn('Could not delete old checkpoint file %s' % fn)

    path_split = os.path.splitext(self.save_path)
    save_path = '%s_%d%s' % (path_split[0], epoch, path_split[1])
    # add the current file to the deque
    self.checkpoint_files.append(save_path)
    save_obj(self.model.serialize(keep_states=True), save_path)
Developer: rupertsmall, Project: neon, Lines: 19, Source file: callbacks.py
Example 11: get_w2v_vocab
def get_w2v_vocab(fname, max_vocab_size, cache=True):
    """
    Get ordered dict of vocab from google word2vec
    """
    if cache:
        cache_fname = fname.split('.')[0] + ".vocab"

        if os.path.isfile(cache_fname):
            vocab, vocab_size = load_obj(cache_fname)
            neon_logger.display("Word2Vec vocab cached, size is: {}".format(vocab_size))
            return vocab, vocab_size

    with open(fname, 'rb') as f:
        header = f.readline()
        vocab_size, embed_dim = map(int, header.split())
        binary_len = np.dtype('float32').itemsize * embed_dim
        neon_logger.display("Word2Vec vocab size is: {}".format(vocab_size))

        vocab_size = min(max_vocab_size, vocab_size)
        neon_logger.display("Reducing vocab size to: {}".format(vocab_size))

        vocab = OrderedDict()

        for i, line in enumerate(range(vocab_size)):
            word = []
            while True:
                ch = f.read(1)
                if ch == b' ':
                    word = (b''.join(word)).decode('utf-8')
                    break
                if ch != b'\n':
                    word.append(ch)
            f.read(binary_len)
            vocab[word] = i

    if cache:
        save_obj((vocab, vocab_size), cache_fname)

    return vocab, vocab_size
Developer: rlugojr, Project: neon, Lines: 39, Source file: util.py
Example 12: PolySchedule
lr_sched = PolySchedule(total_epochs=10, power=0.5)
opt_gdm = GradientDescentMomentum(0.01, 0.9, wdecay=0.0002, schedule=lr_sched)
opt_biases = GradientDescentMomentum(0.02, 0.9, schedule=lr_sched)
opt = MultiOptimizer({'default': opt_gdm, 'Bias': opt_biases})

if not args.resume:
    # fit the model for 3 epochs
    model.fit(train, optimizer=opt, num_epochs=3, cost=cost, callbacks=callbacks)

train.reset()

# get 1 image
for im, l in train:
    break
train.exit_batch_provider()
save_obj((im.get(), l.get()), 'im1.pkl')
im_save = im.get().copy()

if args.resume:
    (im2, l2) = load_obj('im1.pkl')
    im.set(im2)
    l.set(l2)

# run fprop and bprop on this minibatch save the results
out_fprop = model.fprop(im)
out_fprop_save = [x.get() for x in out_fprop]

im.set(im_save)
out_fprop = model.fprop(im)
out_fprop_save2 = [x.get() for x in out_fprop]
for x, y in zip(out_fprop_save, out_fprop_save2):
    assert np.max(np.abs(x - y)) == 0.0, '2 fprop iterations do not match'
Developer: JediKoder, Project: neon, Lines: 30, Source file: inception.py
Example 13: IOError
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("cache_file", help="path to data cache file")
    args = parser.parse_args()

    cache_file = args.cache_file

    # check for RW access to file
    assert os.path.exists(cache_file), "file does not exist %s" % cache_file
    if not os.access(os.path.abspath(cache_file), os.R_OK | os.W_OK):
        raise IOError("Need to add read and/or write permissions on file %s" % cache_file)

    dc = load_obj(cache_file)

    if "global_mean" not in dc or "img_size" not in dc:
        raise ValueError("data cache file missing global_mean key")

    sz = dc["img_size"]
    gm = dc["global_mean"]

    if len(gm.shape) != 2 or (gm.shape[0] != sz * sz * 3 or gm.shape[1] != 1):
        raise ValueError("global mean shape {} does not match format expected".format(gm.shape))

    # Collapse the full tensor mean into channel means and correct the order (RGB <-> BGR)
    dc["global_mean"] = np.mean(gm.reshape(3, -1), axis=1).reshape(3, 1)[::-1]

    save_obj(dc, cache_file)

    neon_logger.display("%s updated to new format" % cache_file)
Developer: Jokeren, Project: neon, Lines: 29, Source file: update_dataset_cache.py
Example 14: load_data
# ......... part of the code is omitted here .........
    neon_logger.display("open existing vocab file: {}".format(vocab_file_name))
    vocab, rev_vocab, word_count = load_obj(vocab_file_name)
else:
    neon_logger.display("Building vocab file")

    # build vocab
    word_count = defaultdict(int)
    for sent in all_sent:
        sent_words = tokenize(sent)

        if len(sent_words) > max_len_w or len(sent_words) == 0:
            continue

        for word in sent_words:
            word_count[word] += 1

    # sort the word_count , re-assign ids by its frequency. Useful for downstream tasks
    # only done for train vocab
    vocab_sorted = sorted(word_count.items(), key=lambda kv: kv[1], reverse=True)

    vocab = OrderedDict()

    # get word count as array in same ordering as vocab (but with maximum length)
    word_count_ = np.zeros((len(word_count), ), dtype=np.int64)
    for i, t in enumerate(list(zip(*vocab_sorted))[0][:max_vocab_size]):
        word_count_[i] = word_count[t]
        vocab[t] = i
    word_count = word_count_

    # generate the reverse vocab
    rev_vocab = dict((wrd_id, wrd) for wrd, wrd_id in vocab.items())

    neon_logger.display("vocabulary from {} is saved into {}".format(path, vocab_file_name))
    save_obj((vocab, rev_vocab, word_count), vocab_file_name)

vocab_size = len(vocab)
neon_logger.display("\nVocab size from the dataset is: {}".format(vocab_size))
neon_logger.display("\nProcessing and saving training data into {}".format(h5_file_name))

# now process and save the train/valid data
h5f = h5py.File(h5_file_name, 'w', libver='latest')
shape, maxshape = (len(train_sent),), (None)
dt = np.dtype([('text', h5py.special_dtype(vlen=str)),
               ('num_words', np.uint16)])
report_text_train = h5f.create_dataset('report_train', shape=shape,
                                       maxshape=maxshape, dtype=dt,
                                       compression='gzip')
report_train = h5f.create_dataset('train', shape=shape, maxshape=maxshape,
                                  dtype=h5py.special_dtype(vlen=np.int32),
                                  compression='gzip')

# map text to integers
wdata = np.zeros((1, ), dtype=dt)
ntrain = 0
for sent in train_sent:
    text_int = [-1 if t not in vocab else vocab[t] for t in tokenize(sent)]

    # enforce maximum sentence length
    if len(text_int) > max_len_w or len(text_int) == 0:
        continue

    report_train[ntrain] = text_int
    wdata['text'] = clean_string(sent)
    wdata['num_words'] = len(text_int)
Developer: NervanaSystems, Project: neon, Lines: 67, Source file: data_loader.py
Example 15: int
(im_shape, im_scale, gt_boxes, gt_classes,
 num_gt_boxes, difficult) = valid_set.get_metadata_buffers()

num_gt_boxes = int(num_gt_boxes.get())
im_scale = float(im_scale.get())

# retrieve region proposals generated by the model
(proposals, num_proposals) = proposalLayer.get_proposals()

# convert outputs to bounding boxes
boxes = faster_rcnn.get_bboxes(outputs, proposals, num_proposals, num_classes,
                               im_shape.get(), im_scale, max_per_image, thresh, nms_thresh)
all_boxes[mb_idx] = boxes

# retrieve gt boxes
# we add a extra column to track detections during the AP calculation
detected = np.array([False] * num_gt_boxes)
gt_boxes = np.hstack([gt_boxes.get()[:num_gt_boxes] / im_scale,
                      gt_classes.get()[:num_gt_boxes],
                      difficult.get()[:num_gt_boxes], detected[:, np.newaxis]])

all_gt_boxes[mb_idx] = gt_boxes

neon_logger.display('Evaluating detections')
avg_precision = voc_eval(all_boxes, all_gt_boxes, valid_set.CLASSES, use_07_metric=True)

if args.output is not None:
    neon_logger.display('Saving inference results to {}'.format(args.output))
    save_obj([all_boxes, avg_precision], args.output)
Developer: NervanaSystems, Project: neon, Lines: 30, Source file: inference.py
Example 16: __init__
# ......... part of the code is omitted here .........
# how many ROIs to use to train frcnn
self.frcn_rois_per_img = frcn_rois_per_img if frcn_rois_per_img \
    else self.FRCNN_ROI_PER_IMAGE

assert self.img_per_batch == 1, "Only a minibatch of 1 is supported."

self.num_classes = len(self.CLASSES)
self._class_to_index = dict(list(zip(self.CLASSES, list(range(self.num_classes)))))

# shape of the final conv layer
if conv_size:
    self._conv_size = conv_size
else:
    self._conv_size = int(np.floor(self.MAX_SIZE * self.SCALE))

self._feat_stride = 1 / float(self.SCALE)
self._num_scales = len(self.SCALES) * len(self.RATIOS)
self._total_anchors = self._conv_size * self._conv_size * self._num_scales
self.shuffle = shuffle
self.deterministic = deterministic
self.add_flipped = add_flipped

# load the configure the dataset paths
self.config = self.load_data()

# annotation metadata
self._annotation_file_ext = '.xml'
self._annotation_obj_tag = 'object'
self._annotation_class_tag = 'name'
self._annotation_xmin_tag = 'xmin'
self._annotation_xmax_tag = 'xmax'
self._annotation_ymin_tag = 'ymin'
self._annotation_ymax_tag = 'ymax'

# self.rois_per_batch is 128 (2*64) ROIs
# But the image path batch size is self.img_per_batch
# need to control the batch size here
assert self.img_per_batch is 1, "Only a batch size of 1 image is supported"

neon_logger.display("Backend batchsize is changed to be {} "
                    "from Object Localization dataset".format(
                        self.img_per_batch))
self.be.bsz = self.img_per_batch

# 0. allocate buffers
self.allocate()

if not self.mock_db:
    # 1. read image index file
    assert os.path.exists(self.config['image_path']), \
        'Image index file does not exist: {}'.format(self.config['image_path'])
    with open(self.config['index_path']) as f:
        self.image_index = [x.strip() for x in f.readlines()]

    num_images = len(self.image_index)
    self.num_image_entries = num_images * 2 if self.add_flipped else num_images
    self.ndata = self.num_image_entries * self.rois_per_img
else:
    self.num_image_entries = 1
    self.ndata = self.num_image_entries * self.rois_per_img

assert (subset_pct > 0 and subset_pct <= 100), ('subset_pct must be between 0 and 100')

if n_mb is not None:
    self.nbatches = n_mb
else:
    self.nbatches = int(self.num_image_entries / self.img_per_batch * subset_pct / 100)

self.cache_file = self.config['cache_path']

if os.path.exists(self.cache_file) and not rebuild_cache and not self.mock_db:
    self.roi_db = load_obj(self.cache_file)
    neon_logger.display('ROI dataset loaded from file {}'.format(self.cache_file))
elif not self.mock_db:
    # 2. read object Annotations (XML)
    roi_db = self.load_roi_groundtruth()

    if(self.add_flipped):
        roi_db = self.add_flipped_db(roi_db)

    # 3. construct acnhor targets
    self.roi_db = self.add_anchors(roi_db)

    if NORMALIZE_BBOX_TARGETS:
        # 4. normalize bbox targets by class
        self.roi_db = self.normalize_bbox_targets(self.roi_db)

    save_obj(self.roi_db, self.cache_file)
    neon_logger.display('wrote ROI dataset to {}'.format(self.cache_file))
else:
    assert self.mock_db is not None
    roi_db = [self.mock_db]
    self.roi_db = self.add_anchors(roi_db)

# 4. map anchors back to full canvas.
# This is neccessary because the network outputs reflect the full canvas.
# We cache the files in the unmapped state (above) to save memory.
self.roi_db = unmap(self.roi_db)
Developer: Jokeren, Project: neon, Lines: 101, Source file: objectlocalization.py
Example 17: test_model_serialize
def test_model_serialize(backend_default, data):
    (X_train, y_train), (X_test, y_test), nclass = load_mnist(path=data)
    train_set = DataIterator(
        [X_train, X_train], y_train, nclass=nclass, lshape=(1, 28, 28))

    init_norm = Gaussian(loc=0.0, scale=0.01)

    # initialize model
    path1 = Sequential([Conv((5, 5, 16), init=init_norm, bias=Constant(0), activation=Rectlin()),
                        Pooling(2),
                        Affine(nout=20, init=init_norm, bias=init_norm, activation=Rectlin())])
    path2 = Sequential([Affine(nout=100, init=init_norm, bias=Constant(0), activation=Rectlin()),
                        Dropout(keep=0.5),
                        Affine(nout=20, init=init_norm, bias=init_norm, activation=Rectlin())])
    layers = [MergeMultistream(layers=[path1, path2], merge="stack"),
              Affine(nout=20, init=init_norm, batch_norm=True, activation=Rectlin()),
              Affine(nout=10, init=init_norm, activation=Logistic(shortcut=True))]

    tmp_save = 'test_model_serialize_tmp_save.pickle'
    mlp = Model(layers=layers)
    mlp.optimizer = GradientDescentMomentum(learning_rate=0.1, momentum_coef=0.9)
    mlp.cost = GeneralizedCost(costfunc=CrossEntropyBinary())
    mlp.initialize(train_set, cost=mlp.cost)
    n_test = 3
    num_epochs = 3

    # Train model for num_epochs and n_test batches
    for epoch in range(num_epochs):
        for i, (x, t) in enumerate(train_set):
            x = mlp.fprop(x)
            delta = mlp.cost.get_errors(x, t)
            mlp.bprop(delta)
            mlp.optimizer.optimize(mlp.layers_to_optimize, epoch=epoch)
            if i > n_test:
                break

    # Get expected outputs of n_test batches and states of all layers
    outputs_exp = []
    pdicts_exp = [l.get_params_serialize() for l in mlp.layers_to_optimize]
    for i, (x, t) in enumerate(train_set):
        outputs_exp.append(mlp.fprop(x, inference=True))
        if i > n_test:
            break

    # Serialize model
    save_obj(mlp.serialize(keep_states=True), tmp_save)

    # Load model
    mlp = Model(layers=layers)
    mlp.load_weights(tmp_save)

    outputs = []
    pdicts = [l.get_params_serialize() for l in mlp.layers_to_optimize]
    for i, (x, t) in enumerate(train_set):
        outputs.append(mlp.fprop(x, inference=True))
        if i > n_test:
            break

    # Check outputs, states, and params are the same
    for output, output_exp in zip(outputs, outputs_exp):
        assert np.allclose(output.get(), output_exp.get())

    for pd, pd_exp in zip(pdicts, pdicts_exp):
        for s, s_e in zip(pd['states'], pd_exp['states']):
            if isinstance(s, list):  # this is the batch norm case
                for _s, _s_e in zip(s, s_e):
                    assert np.allclose(_s, _s_e)
            else:
                assert np.allclose(s, s_e)

        for p, p_e in zip(pd['params'], pd_exp['params']):
            assert type(p) == type(p_e)
            if isinstance(p, list):  # this is the batch norm case
                for _p, _p_e in zip(p, p_e):
                    assert np.allclose(_p, _p_e)
            elif isinstance(p, np.ndarray):
                assert np.allclose(p, p_e)
            else:
                assert p == p_e

    os.remove(tmp_save)
Developer: GerritKlaschke, Project: neon, Lines: 80, Source file: test_model.py
Example 18: DataIterator
        X_test, y_test, cluster)
    spec_out = nout
    spec_set = DataIterator(
        X_spec, y_spec, nclass=spec_out, lshape=(3, 32, 32))
    spec_test = DataIterator(
        X_spec_test, y_spec_test, nclass=spec_out, lshape=(3, 32, 32))

    # Train the specialist
    specialist, opt, cost = spec_net(nout=spec_out, archive_path=gene_path)
    callbacks = Callbacks(specialist, spec_set, args, eval_set=spec_test)
    callbacks.add_early_stop_callback(early_stop)
    callbacks.add_save_best_state_callback(path)
    specialist.fit(spec_set, optimizer=opt,
                   num_epochs=specialist.epoch_index + num_epochs, cost=cost, callbacks=callbacks)

    # Print results
    print 'Specialist Train misclassification error: ', specialist.eval(spec_set, metric=Misclassification())
    print 'Specialist Test misclassification error: ', specialist.eval(spec_test, metric=Misclassification())
    print 'Generalist Train misclassification error: ', generalist.eval(spec_set, metric=Misclassification())
    print 'Generalist Test misclassification error: ', generalist.eval(spec_test, metric=Misclassification())

    # specialists.append(specialist)
    save_obj(specialist.serialize(), path)
except:
    path = confusion_matrix_name + '_' + clustering_name + '_' + str(num_clusters) + 'clusters/'
    print 'Failed for ', path
    failed.append(path)

for f in failed:
    print f
Developer: seba-1511, Project: specialists, Lines: 29, Source file: train_all_specs.py
Example 19: save_weights
def save_weights(self, save_path):
    save_obj(self.model.serialize(keep_states = True), save_path)
Developer: rockhowse, Project: simple_dqn, Lines: 2, Source file: deepqnetwork.py
Example 20: train
def train(self, minibatch, epoch = 0):
    # expand components of minibatch
    prestates, steers, speeds, rewards, poststates, terminals = minibatch
    assert len(prestates.shape) == 2
    assert len(poststates.shape) == 2
    assert len(steers.shape) == 1
    assert len(speeds.shape) == 1
    assert len(rewards.shape) == 1
    assert len(terminals.shape) == 1
    assert prestates.shape == poststates.shape
    assert prestates.shape[0] == steers.shape[0] == speeds.shape[0] == rewards.shape[0] == poststates.shape[0] == terminals.shape[0]

    if self.target_steps and self.train_iterations % self.target_steps == 0:
        # HACK: serialize network to disk and read it back to clone
        filename = self.save_weights_prefix + "_target.pkl"
        save_obj(self.model.serialize(keep_states = False), filename)
        self.target_model.load_weights(filename)

    # feed-forward pass for poststates to get Q-values
    self._setInput(poststates)
    postq = self.target_model.fprop(self.input, inference = True)
    assert postq.shape == (self.num_actions, self.batch_size)

    # calculate max Q-value for each poststate
    postq = postq.asnumpyarray()
    maxsteerq = np.max(postq[:self.num_steers,:], axis=0)
    assert maxsteerq.shape == (self.batch_size,), "size: %s" % str(maxsteerq.shape)
    maxspeedq = np.max(postq[-self.num_speeds:,:], axis=0)
    assert maxspeedq.shape == (self.batch_size,)

    # feed-forward pass for prestates
    self._setInput(prestates)
    preq = self.model.fprop(self.input, inference = False)
    assert preq.shape == (self.num_actions, self.batch_size)

    # make copy of prestate Q-values as targets
    # HACK: copy() was needed to make it work on CPU
    targets = preq.asnumpyarray().copy()

    # update Q-value targets for actions taken
    for i, (steer, speed) in enumerate(zip(steers, speeds)):
        if terminals[i]:
            targets[steer, i] = float(rewards[i])
            targets[self.num_steers + speed, i] = float(rewards[i])
        else:
            targets[steer, i] = float(rewards[i]) + self.discount_rate * maxsteerq[i]
            targets[self.num_steers + speed, i] = float(rewards[i]) + self.discount_rate * maxspeedq[i]

    # copy targets to GPU memory
    self.targets.set(targets)

    # calculate errors
    deltas = self.cost.get_errors(preq, self.targets)
    assert deltas.shape == (self.num_actions, self.batch_size)
    #assert np.count_nonzero(deltas.asnumpyarray()) == 2 * self.batch_size, str(np.count_nonzero(deltas.asnumpyarray()))

    # calculate cost, just in case
    cost = self.cost.get_cost(preq, self.targets)
    assert cost.shape == (1,1)
    #print "cost:", cost.asnumpyarray()

    # clip errors
    if self.clip_error:
        self.be.clip(deltas, -self.clip_error, self.clip_error, out = deltas)

    # perform back-propagation of gradients
    self.model.bprop(deltas)

    # perform optimization
    self.optimizer.optimize(self.model.layers_to_optimize, epoch)

    '''
    if np.any(rewards < 0):
        preqq = preq.asnumpyarray().copy()
        self._setInput(prestates)
        qvalues = self.model.fprop(self.input, inference = True).asnumpyarray().copy()
        indexes = rewards < 0
        print "indexes:", indexes
        print "preq:", preqq[:, indexes].T
        print "preq':", qvalues[:, indexes].T
        print "diff:", (qvalues[:, indexes]-preqq[:, indexes]).T
        print "steers:", steers[indexes]
        print "speeds:", speeds[indexes]
        print "rewards:", rewards[indexes]
        print "terminals:", terminals[indexes]
        print "preq[0]:", preqq[:, 0]
        print "preq[0]':", qvalues[:, 0]
        print "diff:", qvalues[:, 0] - preqq[:, 0]
        print "deltas:", deltas.asnumpyarray()[:, indexes].T
        raw_input("Press Enter to continue...")
    '''

    # increase number of weight updates (needed for target clone interval)
    self.train_iterations += 1
Developer: tambetm, Project: botmobile, Lines: 94, Source file: deepqnetwork.py
Note: the neon.util.persist.save_obj examples in this article were compiled by 纯净天空 from source-code and documentation hosting platforms such as GitHub and MSDocs. The code snippets are selected from open-source projects contributed by various developers; copyright in the source code remains with the original authors, and any distribution or use should follow the corresponding project's License. Do not reproduce without permission.