This article collects typical usage examples of the Python function pylearn2.utils.serial.load. If you have been wondering what load does, how to call it, or what real uses look like, the curated examples below should help.
The following shows 20 code examples of the load function, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
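Before the examples, here is a minimal sketch of the typical serial.save / serial.load round trip. The file name model_demo.pkl and the dictionary are placeholders standing in for a real model, not taken from any example below:

from pylearn2.utils import serial

# serial.save pickles an arbitrary Python object to the given path;
# serial.load reads it back, picking a loader from the file extension
# (.pkl via pickle, .npy/.npz via numpy, .mat via scipy.io).
# Paths may contain environment variables such as ${PYLEARN2_DATA_PATH},
# which serial.load expands before opening the file.
model = {'weights': [1.0, 2.0, 3.0]}  # placeholder for a trained model
serial.save('model_demo.pkl', model)
restored = serial.load('model_demo.pkl')
assert restored == model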
Example 1: get_processed_dataset

def get_processed_dataset():
    train_path = 'pp_cifar10_train.pkl'
    test_path = 'pp_cifar10_test.pkl'

    if os.path.exists(train_path) and os.path.exists(test_path):
        print 'loading preprocessed data'
        trainset = serial.load(train_path)
        testset = serial.load(test_path)
    else:
        print 'loading raw data...'
        trainset = cifar10.CIFAR10(which_set="train")
        testset = cifar10.CIFAR10(which_set="test")

        pipeline = preprocessing.Pipeline()
        pipeline.items.append(preprocessing.ExtractPatchesWithPosition(
            patch_shape=patch_shape, patches_per_image=patches_per_image))
        pipeline.items.append(preprocessing.GlobalContrastNormalization(
            sqrt_bias=10., use_std=True))
        pipeline.items.append(preprocessing.PCA(
            num_components=num_components, keep_var_fraction=keep_var_fraction))
        pipeline.items.append(preprocessing.ExtractPatchPairs(
            patches_per_image=patches_per_image, num_images=train_size,
            input_width=input_width))

        trainset.apply_preprocessor(preprocessor=pipeline, can_fit=True)

        # the pkl-ing is having issues, the dataset is maybe too big.
        serial.save(train_path, trainset)
        serial.save(test_path, testset)

    # this path will be used for visualizing weights after training is done
    trainset.yaml_src = '!pkl: "%s"' % train_path
    testset.yaml_src = '!pkl: "%s"' % test_path

    return trainset, testset

Author: capybaralet | Project: current | Lines: 32 | Source: testrun.py
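A note on the yaml_src assignment at the end of example 1: pylearn2's YAML parser understands the !pkl tag, which is how downstream tools such as the weight-visualization scripts recover the exact dataset a model was trained on. A minimal sketch, assuming the pp_cifar10_train.pkl file saved above exists:

from pylearn2.config import yaml_parse

# The !pkl tag tells the YAML parser to unpickle the given file,
# so this returns the same preprocessed dataset object saved above.
trainset = yaml_parse.load('!pkl: "pp_cifar10_train.pkl"')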
Example 2: __init__

def __init__(self, model):
    self.filename = model['filename']
    try:
        try:
            self.model = serial.load(self.filename)
        except Exception:
            # Retry with the path rooted at the MMDAEdaes environment variable
            self.filename = os.environ['MMDAEdaes'] + self.filename
            self.model = serial.load(self.filename)
    except Exception as e:
        print("error loading {}:".format(self.filename))
        print(e)
        raise  # __init__ cannot return a value, so propagate the failure
    if not hasattr(self.model, 'sequence'):
        self.__extract_sequence()
    if not hasattr(self.model, 'mean') or not hasattr(self.model, 'std') \
            or not hasattr(self.model, 'normalise'):
        self.__calc_mean_std()
    if not hasattr(self.model, 'function'):
        self.__create_function()
    if not hasattr(self.model, 'mmdae_type'):
        self.__determine_type()
    self.dtype = theano.config.floatX
    self.__clean()

Author: jamessergeant | Project: MMDAEforControl | Lines: 29 | Source: node_classes.py
Example 3: main

def main():
    args = check_argv()

    model = serial.load(args.model_fn)
    print "Constructing model"
    if model.__class__ == pylearn2.models.mlp.MLP:
        # This is a correspondence model
        print "Loaded:", args.model_fn
        if args.use_layer is not None:
            use_layer = args.use_layer
            print "Using encoding from layer", use_layer, "out of", len(model.layers)
        else:
            print "Using last layer out of", len(model.layers)
            use_layer = -1
        dAEs = [l.layer_content for l in model.layers[:use_layer]]
    else:
        # This is a normal stacked dAE: get the other layers from filename
        assert args.use_layer is None, "layer already specified in filename"
        model_dir, basename = path.split(args.model_fn)
        use_layer = int(basename.split(".")[-2].replace("layer", ""))
        # if use_layer != 0:
        #     This is not a single-layer model
        dAEs = []
        for layer in range(use_layer + 1):
            model_fn = path.join(
                model_dir,
                ".".join(basename.split(".")[:-2]) + ".layer" + str(layer) + ".pkl"
                )
            print "Loading:", model_fn
            dAEs.append(serial.load(model_fn))
    model = pylearn2.models.autoencoder.DeepComposedAutoencoder(dAEs)

    input_dataset = dict(serial.load(args.input_fn))

    # Symbolic matrix of items to encode
    x = T.dmatrix('x')
    encoder = model.encode(x)
    encode_func = function([x], encoder)

    # Perform encoding
    print "Performing encoding"
    result = {}
    for (label, features) in input_dataset.items():
        result[label] = encode_func(features)

    # Write encoded output (`output_dir` is a global defined elsewhere
    # in the original script)
    input_basename = path.splitext(path.split(args.input_fn)[-1])[0]
    model_dir, model_basename = path.split(args.model_fn)
    model_basename = path.splitext(model_basename)[0]
    model_basename = path.split(model_dir)[-1] + "." + model_basename
    encoded_fn = path.join(
        output_dir,
        "encoded." + input_basename + "." + model_basename +
        (".layer" + str(args.use_layer) if args.use_layer is not None else "") + ".npz"
        )
    print "Writing encoding:", encoded_fn
    np.savez(encoded_fn, **result)

Author: getupyang | Project: speech_correspondence | Lines: 60 | Source: encode.py
Example 4: load_data

def load_data(which="original", center=False, scale=False):
    if which == "original":
        path = "/data/lisa/data/faces/GoogleDataset/Clean/latest.pkl"
        data = serial.load(path)
        data_x = data[0]
        data_y = data[1]
        assert len(data_x) == len(data_y)
        if center:
            data_x -= 0.5
    elif which == "kaggle":
        path = "/data/lisa/data/faces/EmotiW/preproc/samira/KGL-AFEW/"
        data_x = serial.load(path + "train_kaggle_x.npy")
        data_y = serial.load(path + "train_kaggle_y.npy")
        assert len(data_x) == len(data_y)
        if scale:
            data_x /= 255.0
            if center:
                data_x -= 0.5
        elif center:
            data_x -= 127.5

    # Turn the integer labels into one-hot vectors over the 7 emotion classes
    one_hot = np.zeros((data_y.shape[0], 7), dtype="float32")
    for i in xrange(data_y.shape[0]):
        one_hot[i, data_y[i]] = 1.0
    data_y = one_hot

    return data_x.reshape(data_x.shape[0], 48 * 48).astype("float32"), data_y

Author: YangXS | Project: lisa_emotiw | Lines: 27 | Source: googleTFD.py
Example 5: train_layer4

def train_layer4(supervised=True):
    global unsup_dataset, sup_dataset

    # Process unsupervised layer 4
    unsup_dataset = TransformerDataset(raw=unsup_dataset,
                                       transformer=serial.load(layer3_unsup_model))
    model = DenoisingAutoencoder(BinomialCorruptor(corruption_level=0.002),
                                 nvis=nhid3, nhid=nhid4, act_enc='tanh',
                                 act_dec=None, irange=0.5)
    training_alg = SGD(cost=MeanSquaredReconstructionError(), learning_rate=1e-4,
                       batch_size=batch_size, monitoring_dataset=unsup_dataset,
                       termination_criterion=EpochCounter(max_epochs=max_epochs))
    extensions = [MonitorBasedLRAdjuster()]
    experiment = Train(dataset=unsup_dataset, model=model, algorithm=training_alg,
                       save_path=layer4_unsup_model, save_freq=50,
                       allow_overwrite=True, extensions=extensions)
    experiment.main_loop()

    if supervised:
        # Process supervised layer 4: stack the pretrained layers and fine-tune
        layers = [PretrainedLayer(layer_name='h1',
                                  layer_content=serial.load(layer1_unsup_model),
                                  freeze_params=False),
                  PretrainedLayer(layer_name='h2',
                                  layer_content=serial.load(layer2_unsup_model),
                                  freeze_params=False),
                  PretrainedLayer(layer_name='h3',
                                  layer_content=serial.load(layer3_unsup_model),
                                  freeze_params=False),
                  PretrainedLayer(layer_name='h4',
                                  layer_content=serial.load(layer4_unsup_model),
                                  freeze_params=False),
                  Softmax(n_classes=class_number, layer_name='y', irange=0.5)]
        model = MLP(layers=layers, batch_size=sup_dataset.y.shape[0],
                    nvis=nvis, layer_name=None)
        training_alg = SGD(learning_rate=1e-3, monitoring_dataset=sup_dataset,
                           termination_criterion=EpochCounter(max_epochs=max_epochs_mlp))
        experiment = Train(dataset=sup_dataset, model=model, algorithm=training_alg,
                           save_path=layer4_sup_model, save_freq=50,
                           allow_overwrite=True, extensions=extensions)
        experiment.main_loop()

        # Save the fine-tuned layers back over the unsupervised models
        serial.save(layer1_unsup_model, model.layers[0].layer_content)
        serial.save(layer2_unsup_model, model.layers[1].layer_content)
        serial.save(layer3_unsup_model, model.layers[2].layer_content)
        serial.save(layer4_unsup_model, model.layers[3].layer_content)

Author: lluiscastrejonsubira | Project: Network-Oracle | Lines: 26 | Source: trainer_v2.py
Example 6: _load_data

def _load_data(self, which_set, context_len, data_mode):
    if data_mode not in ['words', 'chars']:
        raise ValueError("Only 'words' and 'chars' are possible values "
                         "for data_mode, not %s" % (data_mode,))

    path = "${PYLEARN2_DATA_PATH}/PennTreebankCorpus/"
    npz_data = serial.load(path + "penntree_char_and_word.npz")
    if which_set == 'train':
        self._raw_data = npz_data['train_' + data_mode]
    elif which_set == 'valid':
        self._raw_data = npz_data['valid_' + data_mode]
    elif which_set == 'test':
        self._raw_data = npz_data['test_' + data_mode]
    else:
        raise ValueError("Dataset must be one of 'train', 'valid' "
                         "or 'test'")

    # Use word.lower() because the dictionary contains a single word
    # that is capitalized for some reason: N
    npz_data = serial.load(path + "dictionaries.npz")
    self._vocabulary = dict((word.lower(), word_index) for word_index, word
                            in enumerate(npz_data['unique_' + data_mode]))

    if data_mode == 'words':
        self._unknown_index = 591
        self._max_labels = 10000
    else:
        self._unknown_index = 50
        self._max_labels = 51
    self._is_case_sensitive = False

Author: 123fengye741 | Project: pylearn2 | Lines: 31 | Source: penntree.py
Example 7: __call__

def __call__(self):
    print 'loading model'
    if self.num_filters == 1600:
        d = serial.load('${USERDIR}/galatea/s3c/sc_vq_demo/omp1.mat')
    elif self.num_filters == 800:
        d = serial.load('/RQexec/goodfell/omp1_800.mat')
    else:
        assert False
    self.W = sharedX(d['dictionary'].T)
    self.size = int(np.sqrt(self.W.get_value().shape[0] / 3))

    if self.chunk_size is not None:
        dataset_family = self.dataset_family
        which_set = self.which_set
        # Descriptors are indexed by set name and patch size
        dataset_descriptor = dataset_family[which_set][self.size]

        num_examples = dataset_descriptor.num_examples
        assert num_examples % self.chunk_size == 0

        # Extract features one chunk at a time
        self.chunk_id = 0
        for i in xrange(0, num_examples, self.chunk_size):
            self.restrict = (i, i + self.chunk_size)
            self._execute()
            self.chunk_id += 1
    else:
        self._execute()

Author: cc13ny | Project: galatea | Lines: 31 | Source: extract_features_omp.py
Example 8: load_cifar10

def load_cifar10():
    from pylearn2.utils import serial
    from pylearn2.datasets.zca_dataset import ZCA_Dataset
    # from pylearn2.datasets.cifar10 import CIFAR10
    import theano

    def rotate_and_convert_grayscale(img):
        reshaped = img.reshape(32, 32, 3, order="F")
        rotated = np.rot90(reshaped, k=3)
        # Rec. 601 luma weights for RGB-to-grayscale conversion
        grayscaled = np.dot(rotated[:, :, :3], [0.299, 0.587, 0.114])
        return grayscaled

    def transform(img_set):
        result = []
        # Convert all images to grayscale and flatten the shape
        for img in img_set:
            # result.append(rotate_and_convert_grayscale(img).ravel())
            result.append(img.ravel())
        return np.array(result)

    # train_set = CIFAR10(which_set='train', start=0, stop=45000)
    # valid_set = CIFAR10(which_set='train', start=45000, stop=50000)
    # test_set = CIFAR10(which_set='test')
    data_path = os.getenv("PYLEARN2_DATA_PATH")
    whitened_path = os.path.join(data_path, "cifar10_cpu", "pylearn2_gcn_whitened")
    preprocessed_train_dataset = serial.load(os.path.join(whitened_path, "train.pkl"))
    preprocessed_test_dataset = serial.load(os.path.join(whitened_path, "test.pkl"))
    preprocessor = serial.load(os.path.join(whitened_path, "preprocessor.pkl"))

    train_set = ZCA_Dataset(preprocessed_train_dataset, preprocessor, start=0, stop=45000)
    valid_set = ZCA_Dataset(preprocessed_train_dataset, preprocessor, start=45000, stop=50000)
    test_set = ZCA_Dataset(preprocessed_test_dataset, preprocessor)

    # Flatten the images (the grayscale conversion is commented out above)
    train_set.X = transform(train_set.X)
    valid_set.X = transform(valid_set.X)
    test_set.X = transform(test_set.X)

    def shared_y_cast(y):
        shared_y = theano.shared(np.asarray(y, dtype=theano.config.floatX), borrow=True)
        return T.cast(shared_y, "int32")

    train_set_tuple = (
        theano.shared(np.array(train_set.X, dtype=theano.config.floatX), borrow=True),
        shared_y_cast(train_set.y.ravel()),
    )
    valid_set_tuple = (
        theano.shared(np.array(valid_set.X, dtype=theano.config.floatX), borrow=True),
        shared_y_cast(valid_set.y.ravel()),
    )
    test_set_tuple = (
        theano.shared(np.array(test_set.X, dtype=theano.config.floatX), borrow=True),
        shared_y_cast(test_set.y.ravel()),
    )
    return [train_set_tuple, valid_set_tuple, test_set_tuple]

Author: SatwantKumar | Project: incremental-kd | Lines: 60 | Source: logistic_sgd.py
Example 9: get_all_datasets

def get_all_datasets(tot, preprocessors):
    for ii, preprocessor in enumerate(preprocessors):
        train_path = DATA_DIR + 'train_' + preprocessor + '_preprocessed.pkl'
        valid_path = DATA_DIR + 'valid_' + preprocessor + '_preprocessed.pkl'
        tottrain_path = DATA_DIR + 'tottrain_' + preprocessor + '_preprocessed.pkl'
        test_path = DATA_DIR + 'test_' + preprocessor + '_preprocessed.pkl'

        if not os.path.exists(train_path) or not os.path.exists(valid_path) \
                or not os.path.exists(test_path):
            print('I cannot find something related to preprocessor: ' + preprocessor)
        else:
            if tot:
                trainset = serial.load(tottrain_path)
            else:
                trainset = serial.load(train_path)
            validset = serial.load(valid_path)
            testset = serial.load(test_path)

            if ii == 0:
                tottrainset = trainset
                totvalidset = validset
                tottestset = testset
            else:
                # Only the training sets are concatenated across preprocessors
                tottrainset.X = np.append(tottrainset.X, trainset.X, axis=0)
                tottrainset.y = np.append(tottrainset.y, trainset.y, axis=0)

    return tottrainset, totvalidset, tottestset

Author: gaoch023 | Project: kaggle | Lines: 26 | Source: digits_data.py
Example 10: __init__

def __init__(self, soln_path, save_path, black_sheep_path):
    # Store all constructor arguments as attributes
    self.__dict__.update(locals())
    del self.self

    soln = serial.load(soln_path)
    self.soln = soln.get_param_vector()

    black_sheep = serial.load(black_sheep_path)
    self.black_sheep = black_sheep.get_param_vector()

Author: cc13ny | Project: galatea | Lines: 9 | Source: __init__.py
Example 11: main

def main():
    data_dir = string.preprocess('${PYLEARN2_DATA_PATH}/stl10')

    print('Loading STL-10 unlabeled and train datasets...')
    downsampled_dir = data_dir + '/stl10_32x32'
    data = serial.load(downsampled_dir + '/unlabeled.pkl')
    supplement = serial.load(downsampled_dir + '/train.pkl')

    print('Concatenating datasets...')
    data.set_design_matrix(np.concatenate((data.X, supplement.X), axis=0))
    del supplement

    print("Preparing output directory...")
    patch_dir = data_dir + '/stl10_patches_8x8'
    serial.mkdir(patch_dir)
    README = open(patch_dir + '/README', 'w')
    README.write(textwrap.dedent("""
    The .pkl files in this directory may be opened in python using
    cPickle, pickle, or pylearn2.serial.load.

    data.pkl contains a pylearn2 Dataset object defining an unlabeled
    dataset of 2 million 8x8 approximately whitened, contrast-normalized
    patches drawn uniformly at random from a downsampled (to 32x32)
    version of the STL-10 train and unlabeled datasets.

    preprocessor.pkl contains a pylearn2 Pipeline object that was used
    to extract the patches and approximately whiten / contrast normalize
    them. This object is necessary when extracting features for
    supervised learning or test set classification, because the
    extracted features must be computed using inputs that have been
    whitened with the ZCA matrix learned and stored by this Pipeline.

    They were created with the pylearn2 script make_stl10_patches_8x8.py.

    All other files in this directory, including this README, were
    created by the same script and are necessary for the other files
    to function correctly.
    """))
    README.close()

    print("Preprocessing the data...")
    pipeline = preprocessing.Pipeline()
    pipeline.items.append(preprocessing.ExtractPatches(patch_shape=(8, 8),
                                                       num_patches=2 * 1000 * 1000))
    pipeline.items.append(
        preprocessing.GlobalContrastNormalization(sqrt_bias=10., use_std=True))
    pipeline.items.append(preprocessing.ZCA())
    data.apply_preprocessor(preprocessor=pipeline, can_fit=True)

    data.use_design_loc(patch_dir + '/data.npy')

    serial.save(patch_dir + '/data.pkl', data)
    serial.save(patch_dir + '/preprocessor.pkl', pipeline)

Author: 123fengye741 | Project: pylearn2 | Lines: 57 | Source: make_stl10_patches_8x8.py
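The README written by example 11 stresses that preprocessor.pkl must be re-applied to any new inputs before feature extraction, so they are whitened with the stored ZCA matrix. A minimal sketch of that step; the path is the one created by the script above, while the DenseDesignMatrix wrapping and the random patches are assumptions for illustration:

from pylearn2.utils import serial
from pylearn2.datasets.dense_design_matrix import DenseDesignMatrix
import numpy as np

# The Pipeline saved by the script above (serial.load expands ${...})
pipeline = serial.load('${PYLEARN2_DATA_PATH}/stl10/stl10_patches_8x8/preprocessor.pkl')
# Hypothetical new 8x8 RGB patches, flattened to 8*8*3 = 192 columns
new_data = DenseDesignMatrix(X=np.random.rand(10, 192).astype('float32'))
# can_fit=False reuses the stored GCN/ZCA statistics instead of refitting
new_data.apply_preprocessor(preprocessor=pipeline, can_fit=False)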
Example 12: train_yaml

def train_yaml(yaml_file):
    train = yaml_parse.load(yaml_file)
    mode1 = serial.load(os.environ["MMDAErbms"] + "/laser_best.pkl")
    mode2 = serial.load(os.environ["MMDAErbms"] + "/command_best.pkl")
    deep = serial.load(os.environ["MMDAErbms"] + "/laser_command_best.pkl")
    models = [mode1, deep, deep, mode1, mode2]

    # Flatten composite layers so they line up with the pretrained models
    layers = list()
    f = theano.config.floatX
    for ii, layer in enumerate(train.model.layers):
        if type(layer) is FlattenerLayer:
            for l in layer.raw_layer.layers:
                layers.append(l)
        elif type(layer) is SplitterLayer:
            layers.append(layer.raw_layer)
        else:
            layers.append(layer)

    for ii, (layer, model) in enumerate(zip(layers, models)):
        if ii < len(layers) / 2:
            # Encoder half: copy weights and biases from the pretrained RBMs
            if type(layer) is Sigmoid:
                if layer.get_weights().shape != model.get_weights().shape:
                    layer.set_weights(
                        model.get_weights()[:layer.get_weights().shape[0], :].astype(f))
                else:
                    layer.set_weights(model.get_weights().astype(f))
                if len(model.get_param_values()) == 4:
                    layer.set_biases(model.get_param_values()[3].astype(f))
                else:
                    layer.set_biases(model.get_param_values()[2].astype(f))
        else:
            # Decoder half: use transposed weights unless tied to an encoder layer
            if type(layer) is Sigmoid:
                if layer.enc_layer is None:
                    layer.set_weights(model.get_weights().transpose().astype(f))
                layer.set_biases(model.get_param_values()[0].astype(f))
            elif type(layer) is LinearGaussian:
                params = model.get_param_values()
                if layer.enc_layer is None:
                    layer.set_weights(params[2].transpose().astype(f))
                layer.set_biases(params[1].astype(f))
                beta = model.get_params()[0].eval()
                if isinstance(beta, N.ndarray):
                    layer.beta.set_value(beta.astype(f))
                elif isinstance(beta, theano.sandbox.cuda.type.CudaNdarrayType):
                    layer.beta.set_value(N.asarray(beta).astype(f))

    del models
    del mode1
    del mode2
    del deep
    train.main_loop()

Author: jamessergeant | Project: MMDAEforControl | Lines: 56 | Source: trainLCSMAE.py
Example 13: extract_data

def extract_data(task_0, task_1):
    model = serial.load(task_0)
    num_params = num_parameters(model)
    valid_0 = model.monitor.channels['valid_y_misclass'].val_record[-1]

    model = serial.load(task_1)
    valid_1 = model.monitor.channels['valid_both_y_misclass'].val_record[-1]

    return num_params, float(valid_0), float(valid_1)

Author: goodfeli | Project: forgetting | Lines: 10 | Source: size_plot.py
Example 14: load_from_numpy

def load_from_numpy(self, filename_root, mmap_mode='r'):
    # Load the data
    inputs = serial.load(filename_root + '_inputs.npy')
    labels = serial.load(filename_root + '_labels.npy')

    # Quick checks to ensure a proper dataset has been loaded
    assert inputs.shape == (62000, 784)
    assert labels.shape[0] == inputs.shape[0]

    return inputs, labels

Author: gdesjardins | Project: pylearn | Lines: 10 | Source: mnist_variations.py
Example 15: _load_data

def _load_data(self, which_set, phone):
    """
    Load the TIMIT data from disk.

    Parameters
    ----------
    which_set : str
        Subset of the dataset to use (either "train", "valid" or "test")
    phone : str
        Phone to load; 'full' loads the complete raw acoustic set
    """
    # Check which_set
    if which_set not in ['train', 'valid', 'test']:
        raise ValueError(which_set + " is not a recognized value. " +
                         "Valid values are ['train', 'valid', 'test'].")

    # Create file paths
    timit_base_path = os.path.join(os.environ["PYLEARN2_DATA_PATH"],
                                   "timit/readable")
    speaker_info_list_path = os.path.join(timit_base_path, "spkrinfo.npy")
    phonemes_list_path = os.path.join(timit_base_path,
                                      "reduced_phonemes.pkl")
    words_list_path = os.path.join(timit_base_path, "words.pkl")
    speaker_features_list_path = os.path.join(timit_base_path,
                                              "spkr_feature_names.pkl")
    speaker_id_list_path = os.path.join(timit_base_path,
                                        "speakers_ids.pkl")
    if phone == 'full':
        raw_wav_path = os.path.join(timit_base_path, which_set + "_x_raw.npy")
    else:
        raw_wav_path = os.path.join(
            "/u/kimtaeho/Documents/2_Project/speech_synthesis/code/datasets/",
            which_set + "_wav_" + phone + ".npy")
    phonemes_path = os.path.join(timit_base_path,
                                 which_set + "_x_phonemes.npy")
    phones_path = os.path.join(timit_base_path,
                               which_set + "_x_phones.npy")
    words_path = os.path.join(timit_base_path, which_set + "_x_words.npy")
    speaker_path = os.path.join(timit_base_path,
                                which_set + "_spkr.npy")

    # Load data. For now most of it is not used, as only the acoustic
    # samples are provided, but this is bound to change eventually.
    # Global data
    if not self.audio_only:
        self.speaker_info_list = serial.load(
            speaker_info_list_path
        ).tolist().toarray()
        self.speaker_id_list = serial.load(speaker_id_list_path)
        self.speaker_features_list = serial.load(speaker_features_list_path)
        self.words_list = serial.load(words_list_path)
        self.phonemes_list = serial.load(phonemes_list_path)

    # Set-related data
    self.raw_wav = serial.load(raw_wav_path)
    if not self.audio_only:
        self.phonemes = serial.load(phonemes_path)
        self.phones = serial.load(phones_path)
        self.words = serial.load(words_path)
        self.speaker_id = numpy.asarray(serial.load(speaker_path), 'int')

Author: ktho22 | Project: speech_synthesis | Lines: 55 | Source: timit_per_phone.py
Example 16: get_checkpoint

def get_checkpoint(self):
    # Prefer the latest checkpoint, fall back to the best-model checkpoint,
    # and return None if neither file exists.
    try:
        checkpoint = self.file_prefix + ".pkl"
        model = serial.load(checkpoint)
    except IOError:
        try:
            checkpoint = self.file_prefix + "_best.pkl"
            model = serial.load(checkpoint)
        except IOError:
            return None
    return model

Author: ecastrow | Project: pl2mind | Lines: 11 | Source: jobman_analysis.py
Example 17: __init__

def __init__(self, lat, lon):
    if not os.path.exists(MLP_FILE % (lat, lon)):
        raise OSError
    if not os.path.exists(LASSO_FILE % (lat, lon)):
        raise OSError
    self.mlp_dropout_model = serial.load(MLP_DROPOUT_FILE % (lat, lon))
    self.mlp_model = serial.load(MLP_FILE % (lat, lon))
    self.lasso_model = pickle.load(open(LASSO_FILE % (lat, lon), 'rb'))
    self.test_data = load_data.load_supervised(1986, 1999, lat, lon, 50,
                                               which='test')
    self.lat = lat
    self.lon = lon

Author: tjvandal | Project: deeply-downscaling | Lines: 12 | Source: batch_compare.py
Example 18: __init__

def __init__(self, lock, modelpath="/home/vartiai6/Autoproject/4layermaxoutbest.mdl",
             preprocessorpath="/home/vartiai6/Autoproject/4layermaxoutpreprocessorbest.pkl"):
    self.lock = lock
    self.modelpath = modelpath
    self.model = serial.load(modelpath)
    self.preprocessorpath = preprocessorpath
    self.preprocessor = serial.load(preprocessorpath)

    # Compile a Theano function for forward propagation through the model
    X = self.model.get_input_space().make_theano_batch()
    Y = self.model.fprop(X)
    self.f = function([X], Y)

Author: TeMaVa | Project: Annotator | Lines: 12 | Source: AutoNetWrapper.py
Example 19: _load_batch_cifar10pre

def _load_batch_cifar10pre(dtype='float64'):
    """
    Load the preprocessed (ZCA-whitened) CIFAR-10 train and test sets.
    """
    preproc = os.path.join(data_dir_cifar10pre, "preprocessor.pkl")
    preprocessor = serial.load(preproc)

    train = os.path.join(data_dir_cifar10pre, "train.pkl")
    train_set = ZCA_Dataset(preprocessed_dataset=serial.load(train),
                            preprocessor=preprocessor, start=0, stop=50000)

    test = os.path.join(data_dir_cifar10pre, "test.pkl")
    test_set = ZCA_Dataset(preprocessed_dataset=serial.load(test),
                           preprocessor=preprocessor)

    return train_set, test_set

Author: Thalnos | Project: 8bit-deep-learning | Lines: 12 | Source: load3.py
Example 20: _load_data

def _load_data(self, data_path, which_set):
    """
    Load the TIMIT data from disk.

    Parameters
    ----------
    data_path : str
        Base path under which the TIMIT data is stored
    which_set : str
        Subset of the dataset to use (either "train", "valid" or "test")
    """
    # Check which_set
    if which_set not in ['train', 'valid', 'test']:
        raise ValueError(which_set + " is not a recognized value. " +
                         "Valid values are ['train', 'valid', 'test'].")

    # Create file paths
    timit_base_path = os.path.join(data_path, "timit/readable")
    speaker_info_list_path = os.path.join(timit_base_path, "spkrinfo.npy")
    phonemes_list_path = os.path.join(timit_base_path,
                                      "reduced_phonemes.pkl")
    words_list_path = os.path.join(timit_base_path, "words.pkl")
    speaker_features_list_path = os.path.join(timit_base_path,
                                              "spkr_feature_names.pkl")
    speaker_id_list_path = os.path.join(timit_base_path,
                                        "speakers_ids.pkl")
    raw_wav_path = os.path.join(timit_base_path, which_set + "_x_raw.npy")
    phonemes_path = os.path.join(timit_base_path,
                                 which_set + "_x_phonemes.npy")
    phones_path = os.path.join(timit_base_path,
                               which_set + "_x_phones.npy")
    words_path = os.path.join(timit_base_path, which_set + "_x_words.npy")
    speaker_path = os.path.join(timit_base_path,
                                which_set + "_spkr.npy")

    # Load data. For now most of it is not used, as only the acoustic
    # samples are provided, but this is bound to change eventually.
    # Global data
    if not self.audio_only:
        self.speaker_info_list = serial.load(
            speaker_info_list_path
        ).tolist().toarray()
        self.speaker_id_list = serial.load(speaker_id_list_path)
        self.speaker_features_list = serial.load(speaker_features_list_path)
        self.words_list = serial.load(words_list_path)
        self.phonemes_list = serial.load(phonemes_list_path)

    # Set-related data
    self.raw_wav = serial.load(raw_wav_path)
    if not self.audio_only:
        self.phonemes = serial.load(phonemes_path)
        self.phones = serial.load(phones_path)
        self.words = serial.load(words_path)
        self.speaker_id = np.asarray(serial.load(speaker_path), 'int')

Author: twuilliam | Project: ift6266h14_wt | Lines: 51 | Source: timit.py
Note: the pylearn2.utils.serial.load examples in this article were compiled by 纯净天空 from source-code and documentation platforms such as GitHub and MSDocs. The snippets are drawn from open-source projects contributed by their respective developers, and copyright remains with the original authors; consult each project's license before redistributing or using the code. Do not reproduce without permission.