• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python patch_viewer.make_viewer函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了 Python 中 pylearn2.gui.patch_viewer.make_viewer 函数的典型用法代码示例。如果您正苦于以下问题:Python make_viewer 函数的具体用法?Python make_viewer 怎么用?Python make_viewer 使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。



下文共展示了 make_viewer 函数的 20 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。

示例1: main

def main(model_path,
         data_path,
         split,
         **kwargs):
    """
    Interactively page through test examples, highlighting misclassified ones.

    The model is loaded from `model_path`, features come from
    `get_features(data_path, split, False)` and the raw images and labels
    from `get_test_data()`. Examples are shown 25 at a time; after each
    batch the per-example verdicts and the batch accuracy are printed and
    the function waits for a line of input. Entering 'q' stops the loop.

    Parameters
    ----------
    model_path : str
        Path handed to `serial.load` to obtain the model.
    data_path : str
        Forwarded to `get_features`.
    split :
        Forwarded to `get_features`.
    **kwargs :
        Accepted and ignored.

    Notes
    -----
    Python 2 code (`xrange`, `raw_input`). Labels are used as
    `class_names[label - 1]`, i.e. they are treated as 1-based.
    """
    model = serial.load(model_path)

    raw_dataset = get_test_data()
    X = get_features(data_path, split, False)
    assert X.shape[0] == 8000

    size = 25
    # Iterate up to X.shape[0] itself: the previous bound of
    # X.shape[0] - size silently dropped the final batch of examples.
    for start in xrange(0, X.shape[0], size):
        stop = start + size
        y = raw_dataset.y[start:stop]
        pred_y = model.predict(X[start:stop, :])

        # True where the prediction is wrong; passed as `activation` so the
        # viewer can mark those patches.
        wrong_mask = y != pred_y

        raw_X = raw_dataset.X[start:stop, :]
        # NOTE(review): the 127.5 divisor presumably maps pixel values into
        # the range the viewer expects with rescale=False -- confirm.
        pv = make_viewer(raw_X / 127.5, rescale=False, is_color=True,
                         activation=wrong_mask)
        pv.show()

        right = 0
        for i in xrange(y.shape[0]):
            true_name = raw_dataset.class_names[y[i] - 1]
            if y[i] == pred_y[i]:
                right += 1
                print(str(start + i) + ': correct (' + true_name + ')')
            else:
                guessed_name = raw_dataset.class_names[pred_y[i] - 1]
                print(str(start + i) + ': mistook ' + true_name +
                      ' for ' + guessed_name)
        # Divide by the number of examples actually scored in this batch
        # rather than the nominal batch size.
        print('accuracy this batch :  ' + str(float(right) / float(y.shape[0])))
        x = raw_input()
        if x == 'q':
            break
开发者ID:cc13ny,项目名称:galatea,代码行数:33,代码来源:fewer.py


示例2: show_sample_pairs

def show_sample_pairs(generator,Noise_Dim,data_obj,filename):
    """
    Save a grid that pairs each generated sample with its nearest data row.

    100 noise vectors are drawn uniformly from [-1, 1), pushed through
    `generator.predict`, and each result is followed in the grid by the row
    of the pooled train/val/test data with the smallest squared Euclidean
    distance to it. The stacked pairs are divided by the pitch maximum, run
    through `patch_thresholding` and `patch_quantize_01`, mapped with
    ``x * 2 - 1`` and written to `filename` through a patch viewer.

    Parameters
    ----------
    generator :
        Object exposing `predict(noise)`.
    Noise_Dim : int
        Width of each noise vector.
    data_obj :
        Provides `pitch_scale`, `X_train`, `X_val` and `X_test`; the three
        arrays are reshaped from 3-D to 2-D before matching.
    filename : str
        Destination handed to `viewer.save`.
    """
    # Pitches are either already scaled (max 1.0) or raw (max 108.0).
    pitch_max = 1.0 if data_obj.pitch_scale else 108.0

    noise = np.random.uniform(-1.0, 1.0, (100, Noise_Dim))
    generated = generator.predict(noise)

    # Pool every split and flatten each example to one row.
    pool = np.concatenate(
        (data_obj.X_train, data_obj.X_val, data_obj.X_test), axis=0)
    pool = pool.reshape(pool.shape[0], pool.shape[1] * pool.shape[2])

    # Even rows hold generated samples, odd rows their nearest neighbours.
    pairs = np.zeros((generated.shape[0] * 2, generated.shape[1]))
    for idx, sample in enumerate(generated):
        nearest = np.argmin(np.square(pool - sample).sum(axis=1))
        pairs[2 * idx, :] = sample
        pairs[2 * idx + 1, :] = pool[nearest, :]

    pairs = patch_quantize_01(patch_thresholding(pairs / pitch_max))
    pairs = pairs * 2.0 - 1.0

    viewer = make_viewer(pairs,
                         grid_shape=(10, 20),
                         patch_shape=(4, pairs.shape[-1] / 4),
                         is_color=False,
                         rescale=False)
    print("Saving %s ..." % filename)
    viewer.save(filename)
开发者ID:lucaskingjade,项目名称:GAN_Music,代码行数:27,代码来源:visualize_music.py


示例3: visualize

def visualize(imgs, prefix, is_color=False):
    """
    Render every slice along the last axis of `imgs` and write JPEG files.

    Each slice ``imgs[:, :, :, i]`` is reshaped to
    ``(shape[0], shape[1] * shape[2])``, laid out with `make_viewer`, and
    converted to an array. In grayscale mode each rendered slice is written
    to ``prefix + str(i) + ".jpg"``. In colour mode (``is_color is True``)
    the slices are converted with `rgb2gray`, and every third one triggers
    a write of the last three combined through `gray2rgb`.

    Parameters
    ----------
    imgs : 4-D array
        Indexed as ``imgs[:, :, :, i]``.
        NOTE(review): presumably (batch, rows, cols, channels) -- confirm.
    prefix : str
        Path prefix for the output files.
    is_color : bool
        Must be exactly ``True`` to select colour mode.

    Returns
    -------
    list
        The rendered arrays, one per processed slice.
    """
    raster = []
    count = 0
    if is_color is True and imgs.shape[3] % 3 != 0:
        # Keep only whole triples of slices. The previous code used
        # numpy.floor(n / 3): a float slice bound, and a count that is not
        # itself guaranteed to be a multiple of 3.
        filts = (imgs.shape[3] // 3) * 3
        imgs = imgs[:, :, :, 0:filts]

    for i in xrange(imgs.shape[3]):
        curr_image = imgs[:, :, :, i]
        flat = curr_image.reshape(
            (curr_image.shape[0], curr_image.shape[1] * curr_image.shape[2]))
        tile = numpy.array(make_viewer(flat, is_color=False).get_img())
        if is_color is True:
            raster.append(rgb2gray(tile))
            # count cycles 0, 1, 2: on the third slice of a triple, merge
            # the last three renders into one colour image.
            if count == 2:
                cv2.imwrite(prefix + str(i) + ".jpg",
                            gray2rgb(raster[i - 2], raster[i - 1], raster[i]))
                count = -1
        else:
            raster.append(tile)
            cv2.imwrite(prefix + str(i) + ".jpg", raster[i])

        count = count + 1
    return raster
开发者ID:seeviewer,项目名称:Convolutional-Neural-Networks,代码行数:20,代码来源:util.py


示例4: get_mat_product_viewer

def get_mat_product_viewer(W1, W2):
    """
    Build a viewer for the matrix product of two layers' weights.

    Parameters
    ----------
    W1 : array-like
        Left operand of the product (first hidden layer).
    W2 : array-like
        Right operand of the product (second hidden layer).

    Returns
    -------
    The object returned by `make_viewer` for ``np.dot(W1, W2).T``.
    """
    # Transposed so the rows handed to make_viewer follow W2's columns.
    return make_viewer(np.dot(W1, W2).T)
开发者ID:123fengye741,项目名称:pylearn2,代码行数:17,代码来源:top_filters.py


示例5: get_weights_report

def get_weights_report(model_path = None, model = None, rescale = 'individual', border = False, norm_sort = False,
        dataset = None):
    """
        Returns a PatchViewer displaying a grid of filter weights

        Parameters:
            model_path: the filepath of the model to make the report on.
            rescale: a string specifying how to rescale the filter images
                        'individual' (default): scale each filter so that it
                            uses as much as possible of the dynamic range
                            of the display under the constraint that 0
                            is gray and no value gets clipped
                        'global' : scale the whole ensemble of weights
                        'none' :   don't rescale
            dataset: a Dataset object to do view conversion for displaying the weights.
                    if not provided one will be loaded from the model's dataset_yaml_src
    """

    if model is None:
        print 'making weights report'
        print 'loading model'
        model = serial.load(model_path)
        print 'loading done'
    else:
        assert model_path is None
    assert model is not None

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True
    else:
        raise ValueError('rescale='+rescale+", must be 'none', 'global', or 'individual'")


    if isinstance(model, dict):
        #assume this was a saved matlab dictionary
        del model['__version__']
        del model['__header__']
        del model['__globals__']
        weights ,= model.values()

        norms = np.sqrt(np.square(weights).sum(axis=1))
        print 'min norm: ',norms.min()
        print 'mean norm: ',norms.mean()
        print 'max norm: ',norms.max()

        return patch_viewer.make_viewer(weights, is_color = weights.shape[1] % 3 == 0)

    weights_view = None
    W = None

    W0,W1,_ = model.get_weights()
    G = model.groups
    

    weights_format = ('v', 'g', 'h')

    W1 = W1.T
    W0 = W0.T
    h1 = W1.shape[0]
    h0 = W0.shape[0]
    print W0.shape, W1.shape

    weights_view1 = dataset.get_weights_view(W1)
    weights_view0 = dataset.get_weights_view(W0)

    hr1 = int(np.ceil(np.sqrt(h1)))
    hc1 = hr1
    
    pv1 = patch_viewer.PatchViewer(grid_shape=(hr1,hc1), patch_shape=weights_view1.shape[1:3],
            is_color = weights_view1.shape[-1] == 3)
    
    hr0 = G.shape[0]
    hc0 = G.sum(1).max()
    
    pv0 = patch_viewer.PatchViewer(grid_shape=(hr0,hc0), patch_shape=weights_view0.shape[1:3],
            is_color = weights_view0.shape[-1] == 3)
            
    null_patch = np.zeros(weights_view0.shape[1:3])

    if border:
        act = 0
    else:
        act = None

    for i in range(0,h1):
        patch = weights_view1[i,...]
        pv1.add_patch( patch, rescale = patch_rescale, activation = act)
        
    for i in range(0,hr0):
        weights_view = weights_view0[i,...]
        g = 0
        for j in range(0, G.shape[1]):
            if G[i,j] == 1:
#.........这里部分代码省略.........
开发者ID:nicholas-leonard,项目名称:delicious,代码行数:101,代码来源:get_overlaping_grouped_weights_report.py


示例6: make_viewer

            continue
        else:
            print "examining this element"
            final = elem

    try:
        print "Trying get_weights topo"
        topo = final.get_weights_topo()
        print "It worked"
        success = True
    except Exception:
        pass

    if success:
        print "Making the viewer and showing"
        make_viewer(topo).show()
        quit()

    try:
        print "Trying get_weights"
        weights = final.get_weights()
        print "It worked"
        success = True
    except NotImplementedError:
        i -= 1 # skip over SpaceConverter, etc.
print "Out of the while loop"


print "weights shape ", weights.shape
viewer = make_viewer(weights, is_color=weights.shape[1] % 3 == 0 and weights.shape[1] != 48*48)
print "image shape ", viewer.image.shape
开发者ID:AdityoSanjaya,项目名称:adversarial,代码行数:31,代码来源:show_gen_weights.py


示例7: int

# Script: load a pickled model and display 100 samples drawn from its
# generator in a patch viewer.
from pylearn2.utils import serial
import sys
# Requires exactly one command-line argument (the model path); any other
# argument count makes this unpacking raise ValueError.
_, model_path = sys.argv
model = serial.load(model_path)
from pylearn2.gui.patch_viewer import make_viewer
space = model.generator.get_output_space()
total_dimension = space.get_total_dimension()
import numpy as np
# Samples are treated as single-channel; the 3-channel detection below is
# commented out.
num_colors = 1
#if total_dimension % 3 == 0:
#    num_colors = 3
# Side length of a square image holding total_dimension / num_colors values.
# int() truncates, so this presumes that count is a perfect square.
w = int(np.sqrt(total_dimension / num_colors))
from pylearn2.space import Conv2DSpace
desired_space = Conv2DSpace(shape=[w, w], num_channels=num_colors, axes=('b',0,1,'c'))
# Draw 100 samples, convert them from the generator's output space to the
# image space built above, and evaluate the symbolic result.
samples = space.format_as(batch=model.generator.sample(100),
        space=desired_space).eval()
# Report the value range of the samples as a (min, mean, max) tuple.
print (samples.min(), samples.mean(), samples.max())
# NOTE(review): presumably maps samples from [0, 1] to [-1, 1] for display --
# confirm the generator's output range.
viewer = make_viewer(samples * 2.0 - 1.0)
viewer.show()
开发者ID:AdityoSanjaya,项目名称:adversarial,代码行数:19,代码来源:show_samples_tfd.py


示例8: get_weights_report

def get_weights_report(model_path=None,
                       model=None,
                       rescale='individual',
                       border=False,
                       norm_sort=False,
                       dataset=None):
    """
    Returns a PatchViewer displaying a grid of filter weights

    Parameters
    ----------
    model_path : str
        Filepath of the model to make the report on.
    rescale : str
        A string specifying how to rescale the filter images:
            - 'individual' (default) : scale each filter so that it
                  uses as much as possible of the dynamic range
                  of the display under the constraint that 0
                  is gray and no value gets clipped
            - 'global' : scale the whole ensemble of weights
            - 'none' :   don't rescale
    dataset : pylearn2.datasets.dataset.Dataset
        Dataset object to do view conversion for displaying the weights. If
        not provided one will be loaded from the model's dataset_yaml_src.

    Returns
    -------
    WRITEME
    """

    if model is None:
        logger.info('making weights report')
        logger.info('loading model')
        model = serial.load(model_path)
        logger.info('loading done')
    else:
        assert model_path is None
    assert model is not None

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True
    else:
        raise ValueError('rescale=' + rescale +
                         ", must be 'none', 'global', or 'individual'")


    if isinstance(model, dict):
        #assume this was a saved matlab dictionary
        del model['__version__']
        del model['__header__']
        del model['__globals__']
        keys = [key for key in model \
                if hasattr(model[key], 'ndim') and model[key].ndim == 2]
        if len(keys) > 2:
            key = None
            while key not in keys:
                logger.info('Which is the weights?')
                for key in keys:
                    logger.info('\t{0}'.format(key))
                key = input()
        else:
            key, = keys
        weights = model[key]

        norms = np.sqrt(np.square(weights).sum(axis=1))
        logger.info('min norm: {0}'.format(norms.min()))
        logger.info('mean norm: {0}'.format(norms.mean()))
        logger.info('max norm: {0}'.format(norms.max()))

        return patch_viewer.make_viewer(weights,
                                        is_color=weights.shape[1] % 3 == 0)

    weights_view = None
    W = None

    try:
        weights_view = model.get_weights_topo()
        h = weights_view.shape[0]
    except NotImplementedError:

        if dataset is None:
            logger.info('loading dataset...')
            control.push_load_data(False)
            dataset = yaml_parse.load(model.dataset_yaml_src)
            control.pop_load_data()
            logger.info('...done')

        try:
            W = model.get_weights()
        except AttributeError as e:
            reraise_as(AttributeError("""
Encountered an AttributeError while trying to call get_weights on a model.
This probably means you need to implement get_weights for this model class,
#.........这里部分代码省略.........
开发者ID:123fengye741,项目名称:pylearn2,代码行数:101,代码来源:get_weights_report.py


示例9: sharedX

# Script excerpt: draw inpainting samples from a model's generator for 100
# test-set examples and show them in a patch viewer.
# NOTE(review): `sys` and `serial` are used below but not imported in this
# excerpt -- presumably imported earlier in the file.
# Requires exactly one command-line argument (the model path); any other
# argument count makes the unpacking raise ValueError.
_, model_path = sys.argv
model = serial.load(model_path)
from pylearn2.gui.patch_viewer import make_viewer
space = model.generator.get_output_space()
from pylearn2.config import yaml_parse
import numpy as np

# Rebuild the dataset from the YAML description stored on the model, then
# switch to its test set.
dataset = yaml_parse.load(model.dataset_yaml_src)
dataset = dataset.get_test_set()

# Passed through to make_viewer unchanged; presumably None lets the viewer
# choose its own grid layout.
grid_shape = None

from pylearn2.utils import sharedX
X = sharedX(dataset.get_batch_topo(100))
# The second return value (bound to `ignore`) is not used.
samples, ignore = model.generator.inpainting_sample_and_noise(X)
samples = samples.eval()
total_dimension = space.get_total_dimension()
# Guess the channel count: 3 when the flattened size is divisible by 3,
# otherwise 1.
num_colors = 1
if total_dimension % 3 == 0:
    num_colors = 3
w = int(np.sqrt(total_dimension / num_colors))
from pylearn2.space import Conv2DSpace
# NOTE(review): `desired_space` (and therefore `w` / `num_colors`) is not
# used anywhere in the rest of this excerpt; the samples are displayed as
# returned by the generator.
desired_space = Conv2DSpace(shape=[w, w], num_channels=num_colors, axes=('b',0,1,'c'))
# Colour display is decided from the samples' last axis, not from num_colors.
is_color = samples.shape[-1] == 3
# Report the value range of the samples as a (min, mean, max) tuple.
print (samples.min(), samples.mean(), samples.max())
# Hack for detecting MNIST [0, 1] values. Otherwise we assume centered images
if samples.min() >0:
    samples = samples * 2.0 - 1.0
viewer = make_viewer(samples, grid_shape=grid_shape, is_color=is_color)
viewer.show()
开发者ID:AdityoSanjaya,项目名称:adversarial,代码行数:30,代码来源:show_inpaint_samples.py


示例10: CIFAR10

    print 'loading dataset'
    if cifar10:
        print 'CIFAR10 detected'
        dataset = CIFAR10(which_set = "train")
    elif cifar100:
        print 'CIFAR100 detected'
        dataset = CIFAR100(which_set = 'train')
    elif stl10:
        print 'STL10 detected'
        dataset = serial.load('${PYLEARN2_DATA_PATH}/stl10/stl10_32x32/train.pkl')
    X = dataset.get_design_matrix()[batch_start:batch_start + batch_size,:]

    size = np.sqrt(model.nvis/3)

    if cifar10 or cifar100:
        pv1 = make_viewer( (X-127.5)/127.5, is_color = True, rescale = False)
    elif stl10:
        pv1 = make_viewer( X/127.5, is_color = True, rescale = False)

    dataset.set_design_matrix(X)

    patchifier = ExtractGridPatches( patch_shape = (size,size), patch_stride = (1,1) )


    if size == 8:
        if cifar10:
            pipeline = serial.load('${GOODFELI_TMP}/cifar10_preprocessed_pipeline_2M.pkl')
        elif stl10:
            assert False
    elif size ==6:
        if cifar10:
开发者ID:cc13ny,项目名称:galatea,代码行数:31,代码来源:feature_viewer.py


示例11: get_weights_report

def get_weights_report(model_path = None, model = None, rescale = 'individual', border = False, norm_sort = False,
        dataset = None):
    """
        Returns a PatchViewer displaying a grid of filter weights

        Parameters:
            model_path: the filepath of the model to make the report on.
            rescale: a string specifying how to rescale the filter images
                        'individual' (default): scale each filter so that it
                            uses as much as possible of the dynamic range
                            of the display under the constraint that 0
                            is gray and no value gets clipped
                        'global' : scale the whole ensemble of weights
                        'none' :   don't rescale
            dataset: a Dataset object to do view conversion for displaying the weights.
                    if not provided one will be loaded from the model's dataset_yaml_src
    """

    if model is None:
        print 'making weights report'
        print 'loading model'
        model = serial.load(model_path)
        print 'loading done'
    else:
        assert model_path is None
    assert model is not None

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True
    else:
        raise ValueError('rescale='+rescale+", must be 'none', 'global', or 'individual'")


    if isinstance(model, dict):
        #assume this was a saved matlab dictionary
        del model['__version__']
        del model['__header__']
        del model['__globals__']
        weights ,= model.values()

        norms = np.sqrt(np.square(weights).sum(axis=1))
        print 'min norm: ',norms.min()
        print 'mean norm: ',norms.mean()
        print 'max norm: ',norms.max()

        return patch_viewer.make_viewer(weights, is_color = weights.shape[1] % 3 == 0)

    weights_view = None
    W = None

    try:
        weights_view = model.get_weights_topo()
        h = weights_view.shape[0]
    except Exception, e:

        if dataset is None:
            print 'loading dataset...'
            control.push_load_data(False)
            dataset = yaml_parse.load(model.dataset_yaml_src)
            control.pop_load_data()
            print '...done'

        if hasattr(model,'get_weights'):
            W = model.get_weights()

        if 'weightsShared' in dir(model):
            W = model.weightsShared.get_value()

        if 'W' in dir(model):
            if hasattr(model.W,'__array__'):
                warnings.warn('model.W is an ndarray; I can figure out how to display this but that seems like a sign of a bad bug')
                W = model.W
            else:
                W = model.W.get_value()

        has_D = False
        if 'D' in dir(model):
            has_D = True
            D = model.D

        if 'enc_weights_shared' in dir(model):
            W = model.enc_weights_shared.get_value()


        if W is None:
            raise AttributeError('model does not have a variable with a name like "W", "weights", etc  that pylearn2 recognizes')
开发者ID:HaniAlmousli,项目名称:pylearn,代码行数:93,代码来源:get_weights_report.py


示例12: xrange

# Script excerpt: push 50 batches of 100 examples through `f`, then display
# three arrays read from shared variables as patch grids.
# NOTE(review): `model`, `f`, `ave_V_h`, `ave_V_s` and `ave_V_g` are not
# defined in this excerpt -- presumably set up earlier in the file.
print 'loading dataset'
from pylearn2.config import yaml_parse
# Rebuild the dataset from the YAML description stored on the model.
dataset = yaml_parse.load(model.dataset_yaml_src)

batch_size = 100
batches = 50

for i in xrange(batches):
    print 'batch ',i
    X = dataset.get_batch_design(batch_size)

    # The return value is discarded; presumably f updates the ave_V_* shared
    # variables as a side effect -- confirm against its definition.
    f(X)

H = ave_V_h.get_value()
# S is the elementwise product of the h and s values.
S = H * ave_V_s.get_value()
G = ave_V_g.get_value()


from pylearn2.gui.patch_viewer import make_viewer

# One viewer per array, shown in the order S, H, G.
pv1 = make_viewer(S)
pv1.show()
pv2 = make_viewer(H)
pv2.show()
pv3 = make_viewer(G)
pv3.show()



开发者ID:cc13ny,项目名称:galatea,代码行数:26,代码来源:weighted_ave_of_inputs.py


示例13: range

# Script excerpt: load a stack of layer models, compile encode/decode
# functions for each, then decode a one-hot vector per second-layer unit back
# through both decoders and save the results as an image grid.
# NOTE(review): `os`, `serial`, `theano`, `np`, `patch_viewer` and `layerpath`
# are not defined in this excerpt -- presumably set up earlier in the file.
i = 1
models = []
weights = []
Xs = []
Ys = []
encode_functs = []
decode_functs = []
# Layer files are numbered from 1; keep loading until the next file is
# missing. `models[i-1]` is the model just appended (lists are 0-based).
while os.path.isfile(layerpath(i)):
	models.append(serial.load(layerpath(i)))
	# Compile a function mapping an input batch to its encoding.
	I = models[i-1].get_input_space().make_theano_batch()	
	E = models[i-1].encode(I)
	encode_functs.append(theano.function( [I], E ))
	# Compile a function mapping a hidden batch back through the decoder.
	H = models[i-1].get_output_space().make_theano_batch()
	D = models[i-1].decode(H)
	decode_functs.append(theano.function( [H], D ))
	weights.append(models[i-1].get_weights())
	i += 1

# One row per column of weights[1], one column per row of weights[0].
# Indexing weights[1] / decode_functs[1] requires at least two layer files.
l1_acts = np.zeros([weights[1].shape[1],weights[0].shape[0]])
for k in range(len(weights[1].T)):
	# One-hot vector selecting unit k, cast to float32 and made 2-D before
	# being passed to the compiled decoders.
	feature = np.zeros(len(weights[1].T))
	feature[k] = 1
	l2_acts = decode_functs[1](np.atleast_2d(feature.astype(np.dtype(np.float32))))
	l1_acts[k] = decode_functs[0](l2_acts)

# NOTE(review): patch_shape=[28,28] presumably matches MNIST-sized inputs
# (the output file name mentions mnist) -- confirm.
pv = patch_viewer.make_viewer(l1_acts, patch_shape=[28,28])
pv.save("mnist_l2_weights_decoder.png")
#scipy.misc.imsave('mnist7_l1_w0.png',l1_act.reshape([28,28]))

开发者ID:Kazjon,项目名称:deep_creeval,代码行数:28,代码来源:sdae_show_weights_decoder.py


示例14: xrange

        nsample.set_value(temp.astype(floatX))

# Script excerpt: burn in the sampler, record `opts.n` snapshots per chain
# together with an energy for each, and display them in a grid.
# NOTE(review): `opts`, `model`, `numpy`, `make_viewer`, `e_nsamples0`,
# `sample_neg_func` and `compute_energy` are not defined in this excerpt --
# presumably set up earlier in the file.

# Burnin of Markov chain.
# NOTE(review): burn-in calls model.sample_neg_func() while the sampling loop
# below calls a bare sample_neg_func() -- confirm both refer to the same step.
for i in xrange(opts.burnin):
    model.sample_neg_func()

# Start actual sampling.
samples = numpy.zeros((opts.batch_size * opts.n, model.n_u[0]))
# Row offsets 0, n, 2n, ...: each chain owns n consecutive rows, and the
# offsets advance by one after every recorded snapshot.
indices = numpy.arange(0, len(samples), opts.n)
energies = numpy.zeros(opts.batch_size * opts.n)

for t in xrange(opts.n):
    samples[indices, :] = e_nsamples0.get_value()
    # skip in between plotted samples
    for i in xrange(opts.skip):
        sample_neg_func()
    energies[indices] = compute_energy()
    indices += 1

# transform energies between 0 and 1
energies -= energies.min()
energy_range = energies.max()
# When every energy is identical the range is 0; leave the zeros in place
# instead of dividing 0 by 0 and filling the activations with NaN.
if energy_range > 0:
    energies /= energy_range

# The leftover `import pdb; pdb.set_trace()` breakpoint was removed so the
# script reaches the viewer without stopping in the debugger.
img = make_viewer(samples,
                  (opts.batch_size, opts.n),
                  (opts.width, opts.height),
                  activation=energies,
                  is_color=opts.color)
img.show()
开发者ID:gdesjardins,项目名称:MFNG,代码行数:30,代码来源:sample.py


示例15: get_dataless_dataset

# Script: load a pickled model and display its visible-layer biases in a
# patch viewer.
import sys
from pylearn2.utils import get_dataless_dataset
from pylearn2.utils import serial
import numpy as np
from pylearn2.gui.patch_viewer import make_viewer

# Requires exactly one command-line argument (the model path); any other
# argument count makes this unpacking raise ValueError.
ignore, model_path = sys.argv

model = serial.load(model_path)
# NOTE(review): `dataset` is not used in the rest of this script.
dataset = get_dataless_dataset(model)

biases = model.visible_layer.get_biases()

# Broadcast the bias vector into a single-row 2-D array (1 x n) before it is
# handed to make_viewer.
biases = np.zeros((1,biases.shape[0]))+biases

# Report the value range of the biases as a (min, mean, max) tuple.
print 'values: ',(biases.min(), biases.mean(), biases.max())

pv = make_viewer(biases)

pv.show()
开发者ID:cc13ny,项目名称:galatea,代码行数:20,代码来源:show_biases.py


示例16: xrange

    temp = numpy.random.randint(0,2, size=model.neg_g.get_value().shape)
    model.neg_g.set_value(temp.astype('float32'))
    temp = numpy.random.randint(0,2, size=model.neg_h.get_value().shape)
    model.neg_h.set_value(temp.astype('float32'))
    v_std = numpy.sqrt(1./softplus(model.beta.get_value()))
    temp = numpy.random.normal(0, v_std, size=model.neg_v.get_value().shape)
    model.neg_v.set_value(temp.astype('float32'))

# Script excerpt: burn in the sampler, record `opts.n` snapshots for a random
# subset of chains, and display them in a grid.
# NOTE(review): `opts`, `model`, `numpy`, `make_viewer` and `sample_neg_func`
# are not defined in this excerpt -- presumably set up earlier in the file.

# Burnin of Markov chain.
for i in xrange(opts.burnin):
    sample_neg_func()

# Start actual sampling.
samples = numpy.zeros((opts.batch_size * opts.n, model.n_v))
# Row offsets 0, n, 2n, ...: each displayed chain owns n consecutive rows.
indices = numpy.arange(0, len(samples), opts.n)

# Pick opts.batch_size chains at random out of the model's batch.
idx = numpy.random.permutation(model.batch_size)[:opts.batch_size]
for t in xrange(opts.n):
    samples[indices,:] = model.neg_ev.get_value()[idx]
    # skip in between plotted samples
    print t
    for i in xrange(opts.skip):
        sample_neg_func()
    # Advance every chain's write position to its next row.
    indices += 1

# The two tuples are passed positionally; presumably they are the grid shape
# and the patch shape expected by make_viewer -- confirm its signature.
img = make_viewer(samples,
                  (opts.batch_size, opts.n),
                  (opts.height, opts.width),
                  is_color=opts.color)
img.show()
开发者ID:gdesjardins,项目名称:hossrbm_public,代码行数:30,代码来源:sample.py


示例17: print

# Script: load a pickled k-means object from 'kmeans.pkl' and display its
# `mu` array in a patch viewer.
from pylearn2.utils import serial

kmeans = serial.load('kmeans.pkl')

# `mu` is the same object as kmeans.mu, so the in-place updates below also
# change the loaded object's array.
mu = kmeans.mu

# Report the value range before rescaling as a (min, mean, max) tuple.
print (mu.min(),mu.mean(),mu.max())

# Shift by -0.5 and double, i.e. x -> 2 * (x - 0.5).
# NOTE(review): presumably maps values from [0, 1] to [-1, 1] for display --
# confirm the range of mu.
mu -= .5

mu *= 2

from pylearn2.gui.patch_viewer import make_viewer

pv = make_viewer(mu)

pv.show()
开发者ID:cc13ny,项目名称:galatea,代码行数:17,代码来源:throwaway.py


示例18: get_weights_report

def get_weights_report(model_path = None, model = None, rescale = 'individual', border = False, norm_sort = False,
        dataset = None):
    """
        Returns a PatchViewer displaying a grid of filter weights

        Parameters:
            model_path: the filepath of the model to make the report on.
            rescale: a string specifying how to rescale the filter images
                        'individual' (default): scale each filter so that it
                            uses as much as possible of the dynamic range
                            of the display under the constraint that 0
                            is gray and no value gets clipped
                        'global' : scale the whole ensemble of weights
                        'none' :   don't rescale
            dataset: a Dataset object to do view conversion for displaying the weights.
                    if not provided one will be loaded from the model's dataset_yaml_src
    """

    if model is None:
        print 'making weights report'
        print 'loading model'
        model = serial.load(model_path)
        print 'loading done'
    else:
        assert model_path is None

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True
    else:
        raise ValueError('rescale='+rescale+", must be 'none', 'global', or 'individual'")


    if isinstance(model, dict):
        #assume this was a saved matlab dictionary
        del model['__version__']
        del model['__header__']
        del model['__globals__']
        weights ,= model.values()

        return patch_viewer.make_viewer(weights, is_color = weights.shape[1] % 3 == 0)

    if dataset is None:
        print 'loading dataset...'
        control.push_load_data(False)
        dataset = yaml_parse.load(model.dataset_yaml_src)
        control.pop_load_data()
        print '...done'


    W = None

    if hasattr(model,'get_weights'):
        W = model.get_weights()

    if 'weightsShared' in dir(model):
        W = model.weightsShared.get_value()

    if 'W' in dir(model):
        if hasattr(model.W,'__array__'):
            warnings.warn('model.W is an ndarray; I can figure out how to display this but that seems like a sign of a bad bug')
            W = model.W
        else:
            W = model.W.get_value()

    has_D = False
    if 'D' in dir(model):
        has_D = True
        D = model.D

    if 'enc_weights_shared' in dir(model):
        W = model.enc_weights_shared.get_value()


    if W is None:
        raise AttributeError('model does not have a variable with a name like "W", "weights", etc  that pylearn2 recognizes')

    if len(W.shape) == 2:
        if hasattr(model,'get_weights_format'):
            weights_format = model.get_weights_format()
        if hasattr(model, 'weights_format'):
            weights_format = model.weights_format

        assert hasattr(weights_format,'__iter__')
        assert len(weights_format) == 2
        assert weights_format[0] in ['v','h']
        assert weights_format[1] in ['v','h']
        assert weights_format[0] != weights_format[1]

        if weights_format[0] == 'v':
            W = W.T
        h = W.shape[0]

        if norm_sort:
#.........这里部分代码省略.........
开发者ID:davyfeng,项目名称:pylearn,代码行数:101,代码来源:get_weights_report.py


示例19: get_weights_report

def get_weights_report(model_path = None, model = None, rescale = 'individual', border = False, norm_sort = False,
        dataset = None):
    """
        Returns a PatchViewer displaying a grid of filter weights

        Parameters:
            model_path: the filepath of the model to make the report on.
            rescale: a string specifying how to rescale the filter images
                        'individual' (default): scale each filter so that it
                            uses as much as possible of the dynamic range
                            of the display under the constraint that 0
                            is gray and no value gets clipped
                        'global' : scale the whole ensemble of weights
                        'none' :   don't rescale
            dataset: a Dataset object to do view conversion for displaying the weights.
                    if not provided one will be loaded from the model's dataset_yaml_src
    """

    if model is None:
        print 'making weights report'
        print 'loading model'
        model = serial.load(model_path)
        print 'loading done'
    else:
        assert model_path is None
    assert model is not None

    if rescale == 'none':
        global_rescale = False
        patch_rescale = False
    elif rescale == 'global':
        global_rescale = True
        patch_rescale = False
    elif rescale == 'individual':
        global_rescale = False
        patch_rescale = True
    else:
        raise ValueError('rescale='+rescale+", must be 'none', 'global', or 'individual'")


    if isinstance(model, dict):
        #assume this was a saved matlab dictionary
        del model['__version__']
        del model['__header__']
        del model['__globals__']
        weights ,= model.values()

        norms = np.sqrt(np.square(weights).sum(axis=1))
        print 'min norm: ',norms.min()
        print 'mean norm: ',norms.mean()
        print 'max norm: ',norms.max()

        return patch_viewer.make_viewer(weights, is_color = weights.shape[1] % 3 == 0)

    weights_view = None
    W = None

    try:
        weights_view = model.get_weights_topo()
        h = weights_view.shape[0]
    except NotImplementedError:

        if dataset is None:
            print 'loading dataset...'
            control.push_load_data(False)
            dataset = yaml_parse.load(model.dataset_yaml_src)
            control.pop_load_data()
            print '...done'

        try:
            W = model.get_weights()
        except AttributeError, e:
            raise AttributeError("""
Encountered an AttributeError while trying to call get_weights on a model.
This probably means you need to implement get_weights for this model class,
but look at the original exception to be sure.
If this is an older model class, it may have weights stored as weightsShared,
etc.
Original exception: """+str(e))
开发者ID:casperkaae,项目名称:pylearn2,代码行数:79,代码来源:get_weights_report.py


示例20: sharedX

# Script excerpt: optimise an input batch to minimise the negated, summed
# discriminator output from 10 random starting points, then display the
# resulting inputs.
# NOTE(review): `sys` is used below but not imported in this excerpt --
# presumably imported earlier in the file.
# Requires exactly one command-line argument (the model path); any other
# argument count makes the unpacking raise ValueError.
_, model_path = sys.argv
from pylearn2.utils import serial
model = serial.load(model_path)
d = model.discriminator
import gc
# Only the discriminator is needed; drop the reference to the rest of the
# model and run a collection.
del model
gc.collect()
from pylearn2.utils import sharedX
X = sharedX(d.get_input_space().get_origin_batch(1))
# Negated so that minimising `obj` maximises the summed discriminator output.
obj =  -d.fprop(X).sum()
from pylearn2.optimization.batch_gradient_descent import BatchGradientDescent as BGD
import theano.tensor as T
def norm_constraint(updates):
    """
    Rescale the pending update for X in place.

    `updates` must contain X as a key (asserted). Its entry is divided by
    1e-7 plus the square root of the sum of squares of X.
    """
    assert X in updates
    # NOTE(review): the denominator is the norm of the current X, not of
    # updates[X] -- confirm this is the intended normalisation.
    updates[X] = updates[X] / (1e-7 + T.sqrt(T.sqr(X).sum()))
opt = BGD(objective=obj, params=[X], param_constrainers=[norm_constraint], conjugate=True, reset_conjugate=False,
        reset_alpha=False, line_search_mode='exhaustive', verbose=3, max_iter=20)
results = []
import numpy as np
# Fixed seed so the 10 random restarts are reproducible.
rng = np.random.RandomState([1, 2, 3])
for i in xrange(10):
    # Restart from small Gaussian noise (standard normal scaled by 1/10).
    X.set_value(rng.randn(*X.get_value().shape).astype(X.dtype) / 10.)
    opt.minimize()
    # Reorder the axes and evaluate to a concrete array.
    # NOTE(review): presumably converts to the batch-first layout make_viewer
    # expects -- confirm the discriminator's input axes.
    Xv = X.dimshuffle(3, 1, 2, 0).eval()
    results.append(Xv)
# From here on X is rebound to a plain array holding all 10 results stacked
# along axis 0; it is no longer the shared variable used above.
X = np.concatenate(results, axis=0)
from pylearn2.gui.patch_viewer import make_viewer
v = make_viewer(X)
v.show()

开发者ID:cc13ny,项目名称:galatea,代码行数:29,代码来源:realest.py



注:本文中的pylearn2.gui.patch_viewer.make_viewer函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python patch_viewer.PatchViewer类代码示例发布时间:2022-05-25
下一篇:
Python target_format.OneHotFormatter类代码示例发布时间:2022-05-25
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap