• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python target_format.OneHotFormatter类代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中pylearn2.format.target_format.OneHotFormatter的典型用法代码示例。如果您正苦于以下问题:Python OneHotFormatter类的具体用法?Python OneHotFormatter怎么用?Python OneHotFormatter使用的例子?那么恭喜您, 这里精选的类代码示例或许可以为您提供帮助。



在下文中一共展示了OneHotFormatter类的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: check_one_hot_formatter

 def check_one_hot_formatter(seed, max_labels, dtype, ncases):
     """
     Check that OneHotFormatter.format sets exactly one entry per label.

     Parameters
     ----------
     seed : int
         Seed for the numpy RandomState used to draw the labels.
     max_labels : int
         Number of classes; labels are drawn from [0, max_labels - 1].
     dtype : str or dtype
         Passed to OneHotFormatter as its `dtype` argument.
     ncases : int
         Number of integer labels to draw and format.
     """
     rng = numpy.random.RandomState(seed)
     fmt = OneHotFormatter(max_labels=max_labels, dtype=dtype)
     # randint's upper bound is exclusive, so this draws from the same
     # closed range [0, max_labels - 1] as the deprecated random_integers.
     integer_labels = rng.randint(0, max_labels, size=ncases)
     one_hot_labels = fmt.format(integer_labels)
     # zip() returns an iterator on Python 3; materialise it before len().
     assert len(list(zip(*one_hot_labels.nonzero()))) == ncases
     for case, label in enumerate(integer_labels):
         assert one_hot_labels[case, label] == 1
开发者ID:Alienfeel,项目名称:pylearn2,代码行数:8,代码来源:test_target_format.py


示例2: check_one_hot_formatter_symbolic

 def check_one_hot_formatter_symbolic(seed, max_labels, dtype, ncases):
     """
     Check the symbolic (theano_expr) path of OneHotFormatter.

     The expression is compiled into a theano function, evaluated on
     random integer labels, and the result must contain exactly one
     nonzero entry per label, located at that label's index.

     Parameters
     ----------
     seed : int
         Seed for the numpy RandomState used to draw the labels.
     max_labels : int
         Number of classes; labels are drawn from [0, max_labels - 1].
     dtype : str or dtype
         Passed to OneHotFormatter as its `dtype` argument.
     ncases : int
         Number of integer labels to draw and format.
     """
     rng = numpy.random.RandomState(seed)
     fmt = OneHotFormatter(max_labels=max_labels, dtype=dtype)
     # randint's upper bound is exclusive, so this draws from the same
     # closed range [0, max_labels - 1] as the deprecated random_integers.
     integer_labels = rng.randint(0, max_labels, size=ncases)
     x = theano.tensor.vector(dtype='int64')
     y = fmt.theano_expr(x)
     f = theano.function([x], y)
     one_hot_labels = f(integer_labels)
     # zip() returns an iterator on Python 3; materialise it before len().
     assert len(list(zip(*one_hot_labels.nonzero()))) == ncases
     for case, label in enumerate(integer_labels):
         assert one_hot_labels[case, label] == 1
开发者ID:Alienfeel,项目名称:pylearn2,代码行数:11,代码来源:test_target_format.py


示例3: generate_datasets

def generate_datasets(inputs):
    """
    Build train/valid/test datasets with alternating binary targets.

    Every even-indexed example is labelled class 1 and has 1 added to its
    values (``inputs`` is modified in place); the remaining examples are
    class 0. Examples 0-49 form the training set, 50-74 the validation
    set and 75-99 the test set, each with one-hot targets.

    Parameters
    ----------
    inputs : ndarray
        Topological view with axes ('b', 0, 1, 2, 'c').

    Returns
    -------
    tuple
        (train_set, valid_set, test_set) as VolumetricDenseDesignMatrix.
    """
    n_examples = inputs.shape[0]
    targets = np.zeros(n_examples).astype('int')
    targets[::2] = 1
    class_one = targets == 1
    inputs[class_one] = inputs[class_one] + 1
    targets_one_hot = OneHotFormatter(2).format(targets)

    def subset(start, stop):
        # Wrap one contiguous slice of examples and targets as a dataset.
        return VolumetricDenseDesignMatrix(topo_view=inputs[start:stop],
                                           y=targets_one_hot[start:stop],
                                           axes=('b', 0, 1, 2, 'c'))

    return subset(0, 50), subset(50, 75), subset(75, 100)
开发者ID:robintibor,项目名称:pylearn3dconv,代码行数:13,代码来源:test_training.py


示例4: nll

 def nll(self, data):
     """
     Return the mean negative log-likelihood of the targets in `data`.

     Parameters
     ----------
     data : tuple
         An (X, Y) pair; X is passed to self.score and Y holds integer
         labels that are expanded to one-hot with self.dict_size classes.

     Returns
     -------
     The negated mean, over rows, of the log-probability assigned to
     each row's target.
     """
     inputs, labels = data
     scores = self.score(inputs)
     # Shift each row by its maximum before exponentiating.
     shifted = scores - scores.max(axis=1).dimshuffle(0, 'x')
     row_totals = T.exp(shifted).sum(axis=1).dimshuffle(0, 'x')
     log_prob = shifted - T.log(row_totals)
     one_hot = OneHotFormatter(self.dict_size).theano_expr(labels)
     # Keep only the first and third axes of the one-hot expression
     # (assumes the middle axis has length 1 -- TODO confirm).
     one_hot = one_hot.reshape((one_hot.shape[0], one_hot.shape[2]))
     log_prob_of = (one_hot * log_prob).sum(axis=1)
     assert log_prob_of.ndim == 1
     mean_log_prob = as_floatX(log_prob_of.mean())
     return - mean_log_prob
开发者ID:Sandy4321,项目名称:lisa_intern,代码行数:13,代码来源:__init__.py


示例5: test_one_hot_formatter_simple

def test_one_hot_formatter_simple():
    """
    Yield a randomized one-hot check for every dtype in `all_types`.

    Each yielded tuple is (check function, seed, max_labels, dtype,
    ncases). After the generator cases, a (1, 1) input is checked to
    format to shape (1, 1, 10).
    """
    def check_one_hot_formatter(seed, max_labels, dtype, ncases):
        rng = numpy.random.RandomState(seed)
        fmt = OneHotFormatter(max_labels=max_labels, dtype=dtype)
        # randint's upper bound is exclusive, so this draws from the same
        # closed range [0, max_labels - 1] as the deprecated
        # random_integers.
        integer_labels = rng.randint(0, max_labels, size=ncases)
        one_hot_labels = fmt.format(integer_labels)
        assert len(list(zip(*one_hot_labels.nonzero()))) == ncases
        for case, label in enumerate(integer_labels):
            assert one_hot_labels[case, label] == 1

    rng = numpy.random.RandomState(0)
    for seed, dtype in enumerate(all_types):
        # max_labels is drawn from [1, 30] and ncases from [1, 100]
        # (randint excludes its upper bound).
        yield (check_one_hot_formatter, seed, rng.randint(1, 31), dtype, rng.randint(1, 101))
    fmt = OneHotFormatter(max_labels=10)
    assert fmt.format(numpy.zeros((1, 1), dtype="uint8")).shape == (1, 1, 10)
开发者ID:JesseLivezey,项目名称:pylearn2,代码行数:15,代码来源:test_target_format.py


示例6: _transform_single_channel_data

    def _transform_single_channel_data(self, X, y):
        """
        Cut X and y into windows and build one-hot targets per window.

        Both arrays are reshaped into rows of self.window_size samples.
        Each window's labels are summed and positive sums are set to 1;
        the resulting labels are tiled self.n_channels times and
        converted to one-hot with self.n_classes classes.

        Returns
        -------
        tuple
            (windowed_X, hot_y, None)
        """
        window_shape = (-1, self.window_size)
        windowed_X = np.reshape(X, window_shape)

        # One label per window: any positive label sum becomes 1.
        window_labels = np.reshape(y, window_shape).sum(axis=1)
        window_labels[window_labels > 0] = 1

        # Repeat the window labels once per channel.
        tiled_labels = np.tile(window_labels, self.n_channels)

        hot_y = OneHotFormatter(max_labels=self.n_classes).format(tiled_labels)
        return windowed_X, hot_y, None
开发者ID:akaraspt,项目名称:epilepsy-system,代码行数:15,代码来源:chbmit.py


示例7: __init__

    def __init__(self, space, rng=None):
        """
        Set up the formatter and random stream for `space`.

        Parameters
        ----------
        space : Space
            Its total dimension is the number of one-hot classes and its
            dtype is used for the formatter.
        rng : optional
            Random stream to sample from; a fresh RandomStreams() is
            created when omitted.
        """
        super(OneHotDistribution, self).__init__(space)

        self.dim = space.get_total_dimension()
        self.formatter = OneHotFormatter(self.dim, dtype=space.dtype)

        if rng is None:
            rng = RandomStreams()
        self.rng = rng
开发者ID:HyoungWooPark,项目名称:adversarial,代码行数:7,代码来源:distributions.py


示例8: test_dtype_errors

def test_dtype_errors():
    """Check that floating-point labels raise TypeError in both code paths."""
    fmt = OneHotFormatter(max_labels=50)
    bad_label_calls = (
        # theano_expr with a bad label dtype.
        lambda: fmt.theano_expr(theano.tensor.vector(dtype=theano.config.floatX)),
        # format with a bad label dtype.
        lambda: fmt.format(numpy.zeros(10, dtype='float64')),
    )
    for call in bad_label_calls:
        try:
            call()
        except TypeError:
            raised = True
        else:
            raised = False
        assert raised
开发者ID:Alienfeel,项目名称:pylearn2,代码行数:17,代码来源:test_target_format.py


示例9: check_one_hot_formatter

    def check_one_hot_formatter(seed, max_labels, dtype, ncases, nmultis):
        """
        Check OneHotFormatter.format in "merge" mode on multi-label cases.

        Parameters
        ----------
        seed : int
            Seed for the numpy RandomState used to draw the labels.
        max_labels : int
            Number of classes; labels are drawn from [0, max_labels - 1].
        dtype : str or dtype
            Passed to OneHotFormatter as its `dtype` argument.
        ncases : int
            Number of cases (rows of labels).
        nmultis : int
            Number of labels drawn per case.
        """
        rng = numpy.random.RandomState(seed)
        fmt = OneHotFormatter(max_labels=max_labels, dtype=dtype)
        # randint's upper bound is exclusive, so this draws from the same
        # closed range [0, max_labels - 1] as the deprecated
        # random_integers.
        integer_labels = rng.randint(0, max_labels, size=ncases * nmultis).reshape(ncases, nmultis)

        one_hot_labels = fmt.format(integer_labels, mode="merge")
        # The number of ones would equal ncases * nmultis if no case
        # contained duplicated labels (e.g. [1, 2, 2, 3, 5, 6]). Duplicates
        # are allowed, so that cases can belong to different numbers of
        # classes, and a duplicated label only activates one unit in the
        # k-hot representation. numpy.unique() therefore removes the
        # duplicates per case before the ones are counted.
        unique_labels = numpy.concatenate([numpy.unique(l) for l in integer_labels])
        assert len(list(zip(*one_hot_labels.nonzero()))) == len(unique_labels)
        for case, label in enumerate(integer_labels):
            # Indexing with the full label row picks a duplicated position
            # once per occurrence, so the sum is always nmultis.
            assert numpy.sum(one_hot_labels[case, label]) == nmultis
开发者ID:JesseLivezey,项目名称:pylearn2,代码行数:18,代码来源:test_target_format.py


示例10: OneHotDistribution

class OneHotDistribution(Distribution):
    """Distribution over one-hot vectors, sampled uniformly by index."""

    def __init__(self, space, rng=None):
        """
        Parameters
        ----------
        space : Space
            Its total dimension is the number of one-hot classes and its
            dtype is used for the formatter.
        rng : optional
            Random stream to sample from; a fresh RandomStreams() is
            created when omitted.
        """
        super(OneHotDistribution, self).__init__(space)

        self.dim = space.get_total_dimension()
        self.formatter = OneHotFormatter(self.dim, dtype=space.dtype)

        if rng is None:
            rng = RandomStreams()
        self.rng = rng

    def sample(self, n):
        """Return a symbolic batch of `n` one-hot vectors."""
        highest_index = self.dim - 1
        indices = self.rng.random_integers((n, 1), low=0, high=highest_index)
        return self.formatter.theano_expr(indices, mode='concatenate')
开发者ID:HyoungWooPark,项目名称:adversarial,代码行数:14,代码来源:distributions.py


示例11: _transform_multi_channel_data

    def _transform_multi_channel_data(self, X, y):
        """
        Partition X and y into windows and return a design matrix with
        one-hot targets.

        Returns
        -------
        tuple
            (design matrix, one-hot targets, view converter)
        """
        # Split the recording into windows of self.window_size.
        parted_X, parted_y = self._partition_data(X=X, y=y, partition_size=self.window_size)

        # Swap the last two axes, then insert a singleton axis before the
        # last one to get a 4-D topological view.
        swapped = np.transpose(parted_X, [0, 2, 1])
        n_parts, n_rows, n_cols = swapped.shape
        topo_X = np.reshape(swapped, (n_parts, n_rows, 1, n_cols))

        # Flatten the topological view into a design matrix and verify
        # that the conversion round-trips.
        view_converter = DefaultViewConverter(shape=self.sample_shape,
                                              axes=('b', 0, 1, 'c'))
        design_X = view_converter.topo_view_to_design_mat(topo_X)
        assert np.all(topo_X == view_converter.design_mat_to_topo_view(design_X))

        # One label per window: any positive label sum becomes 1.
        window_labels = np.sum(parted_y, axis=1)
        window_labels[window_labels > 0] = 1
        hot_y = OneHotFormatter(max_labels=self.n_classes).format(window_labels)

        return design_X, hot_y, view_converter
开发者ID:akaraspt,项目名称:epilepsy-system,代码行数:24,代码来源:chbmit.py


示例12: ConditionalGeneratorTestCase

class ConditionalGeneratorTestCase(unittest.TestCase):
    """Tests for ConditionalGenerator with a one-hot condition input."""

    def setUp(self):
        """Build a ConditionalGenerator around a single linear output layer."""
        self.noise_dim = 10
        self.num_labels = 10

        self.condition_dtype = 'uint8'
        self.condition_space = VectorSpace(dim=self.num_labels, dtype=self.condition_dtype)
        self.condition_formatter = OneHotFormatter(self.num_labels, dtype=self.condition_dtype)
        self.condition_distribution = OneHotDistribution(self.condition_space)

        # TODO this nvis stuff is dirty. The ConditionalGenerator should handle it
        self.mlp_nvis = self.noise_dim + self.num_labels
        self.mlp_nout = 1

        # Set up model
        self.mlp = MLP(nvis=self.mlp_nvis, layers=[Linear(self.mlp_nout, 'out', irange=0.1)])
        self.G = ConditionalGenerator(input_condition_space=self.condition_space,
                                      condition_distribution=self.condition_distribution,
                                      noise_dim=self.noise_dim,
                                      mlp=self.mlp)

    def test_conditional_generator_input_setup(self):
        """Check that conditional generator correctly sets up composite
        input layer."""

        label_indices = np.arange(self.num_labels)

        # Feedforward: We want the net to ignore the noise and simply
        # convert the one-hot vector to a number. The noise inputs get
        # zero weights and the i-th condition input gets weight i.
        weights = np.concatenate([np.zeros((self.mlp_nout, self.noise_dim)),
                                  label_indices.reshape((1, -1)).repeat(self.mlp_nout, axis=0)],
                                 axis=1).T.astype(theano.config.floatX)
        self.mlp.layers[0].set_weights(weights)

        inp = (T.matrix(), T.matrix(dtype=self.condition_dtype))
        f = theano.function(inp, self.G.mlp.fprop(inp))

        assert_array_equal(
            f(np.random.rand(self.num_labels, self.noise_dim).astype(theano.config.floatX),
              self.condition_formatter.format(label_indices)),
            label_indices.reshape(self.num_labels, 1))

    def test_sample_noise(self):
        """Test barebones noise sampling."""

        n = T.iscalar()
        cond_inp = self.condition_distribution.sample(n)
        sample_and_noise = theano.function([n], self.G.sample_and_noise(cond_inp, all_g_layers=True)[1])

        # Parenthesised single-argument print is valid on Python 2 and 3.
        print(sample_and_noise(15))
开发者ID:HyoungWooPark,项目名称:adversarial,代码行数:48,代码来源:test_conditional.py


示例13: setUp

    def setUp(self):
        """Build the condition space, formatter, distribution and generator."""
        self.noise_dim = 10
        self.num_labels = 10

        # Condition side: a uint8 vector space with one unit per label.
        cond_dtype = 'uint8'
        self.condition_dtype = cond_dtype
        self.condition_space = VectorSpace(dim=self.num_labels, dtype=cond_dtype)
        self.condition_formatter = OneHotFormatter(self.num_labels, dtype=cond_dtype)
        self.condition_distribution = OneHotDistribution(self.condition_space)

        # TODO this nvis stuff is dirty. The ConditionalGenerator should handle it
        self.mlp_nvis = self.noise_dim + self.num_labels
        self.mlp_nout = 1

        # Model: a single linear layer wrapped in a conditional generator.
        output_layer = Linear(self.mlp_nout, 'out', irange=0.1)
        self.mlp = MLP(nvis=self.mlp_nvis, layers=[output_layer])
        self.G = ConditionalGenerator(
            input_condition_space=self.condition_space,
            condition_distribution=self.condition_distribution,
            noise_dim=self.noise_dim,
            mlp=self.mlp)
开发者ID:HyoungWooPark,项目名称:adversarial,代码行数:19,代码来源:test_conditional.py


示例14: test_bad_arguments

def test_bad_arguments():
    """Check that invalid constructor and call arguments are rejected."""
    # An invalid max_labels must raise ValueError; an invalid dtype
    # identifier must raise TypeError.
    bad_constructor_args = (
        ({'max_labels': -10}, ValueError),
        ({'max_labels': '10'}, ValueError),
        ({'max_labels': 10, 'dtype': 'invalid'}, TypeError),
    )
    for kwargs, expected_error in bad_constructor_args:
        try:
            OneHotFormatter(**kwargs)
        except expected_error:
            raised = True
        else:
            raised = False
        assert raised

    # An invalid ndim must raise ValueError for both format() and
    # theano_expr().
    fmt = OneHotFormatter(max_labels=10)
    bad_ndim_calls = (
        lambda: fmt.format(numpy.zeros((2, 3), dtype='int32')),
        lambda: fmt.theano_expr(theano.tensor.imatrix()),
    )
    for call in bad_ndim_calls:
        try:
            call()
        except ValueError:
            raised = True
        else:
            raised = False
        assert raised
开发者ID:Alienfeel,项目名称:pylearn2,代码行数:40,代码来源:test_target_format.py


示例15: __init__

    def __init__(self, max_labels, dim, **kwargs):
        """
        Create an IndexSpace.

        Parameters
        ----------
        max_labels : int
            Number of distinct classes/labels; every label must be
            strictly less than max_labels. For example, MNIST has ten
            digits, so max_labels = 10.
        dim : int
            Number of indices per example. For example, MNIST has a
            single target label, so dim = 1; an n-gram of word indices
            fed to a neural net language model has dim = n.
        kwargs : dict
            Forwarded to the superclass constructor.
        """

        super(IndexSpace, self).__init__(**kwargs)

        self.dim = dim
        self.max_labels = max_labels
        # Formatter used to expand indices of this space to one-hot.
        self.formatter = OneHotFormatter(self.max_labels)
开发者ID:TheDash,项目名称:pylearn2,代码行数:22,代码来源:__init__.py


示例16: TIMITlpc

import numpy
from pylearn2_timit.timitlpc import TIMITlpc
from pylearn2.space import CompositeSpace, VectorSpace, IndexSpace
from pylearn2.format.target_format import OneHotFormatter

# Load the "valid" split of the TIMIT LPC dataset with the framing
# parameters below.
valid = TIMITlpc("valid", frame_length=160, overlap=159, start=10, stop=11)

# Override the iterator data specs: an index space of 3 phone indices paired
# with a 10-dimensional vector of LPC features.
valid._iter_data_specs = (CompositeSpace((IndexSpace(dim=3,max_labels=61), VectorSpace(dim=10),)), ('phones', 'lpc_features'))

# NOTE(review): max_labels is 62 here but 61 in the IndexSpace above --
# confirm which value is intended.
formatter = OneHotFormatter(max_labels=62)

# Casts a batch of labels to int and formats it in "merge" mode.
f = lambda x: formatter.format(numpy.asarray(x, dtype=int), mode='merge')

# The conversion hook is disabled, so `f` is currently unused.
#valid._iter_convert = [f, None]

# Iterator over 100 random batches of 100 examples each.
it = valid.iterator(mode='random_uniform', batch_size=100, num_batches=100)











开发者ID:jfsantos,项目名称:ift6266h14,代码行数:16,代码来源:test_timit_dataset.py


示例17: __init__

    def __init__(self, which_set, onehot_dtype='uint8',
                 center=False, rescale=False, gcn=None,
                 start=None, stop=None, axes=('b', 0, 1, 'c'),
                 toronto_prepro=False, preprocessor=None):
        """Modified version of the CIFAR10 constructor which creates Y
        as one-hot vectors rather than simple indexes. This is super
        hacky. Sorry, Guido.."""

        # note: there is no such thing as the cifar10 validation set;
        # pylearn1 defined one but really it should be user-configurable
        # (as it is here)

        self.axes = axes

        # we define here:
        dtype = 'uint8'
        ntrain = 50000
        nvalid = 0  # artefact, we won't use it
        ntest = 10000

        # we also expose the following details:
        self.img_shape = (3, 32, 32)
        self.img_size = numpy.prod(self.img_shape)
        self.n_classes = 10
        self.label_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                            'dog', 'frog', 'horse', 'ship', 'truck']

        # prepare loading
        fnames = ['data_batch_%i' % i for i in range(1, 6)]
        datasets = {}
        datapath = os.path.join(
            string_utils.preprocess('${PYLEARN2_DATA_PATH}'),
            'cifar10', 'cifar-10-batches-py')
        for name in fnames + ['test_batch']:
            fname = os.path.join(datapath, name)
            if not os.path.exists(fname):
                raise IOError(fname + " was not found. You probably need to "
                              "download the CIFAR-10 dataset by using the "
                              "download script in "
                              "pylearn2/scripts/datasets/download_cifar10.sh "
                              "or manually from "
                              "http://www.cs.utoronto.ca/~kriz/cifar.html")
            datasets[name] = cache.datasetCache.cache_file(fname)

        lenx = numpy.ceil((ntrain + nvalid) / 10000.) * 10000
        x = numpy.zeros((lenx, self.img_size), dtype=dtype)
        y = numpy.zeros((lenx, 1), dtype=dtype)

        # load train data
        nloaded = 0
        for i, fname in enumerate(fnames):
            _logger.info('loading file %s' % datasets[fname])
            data = serial.load(datasets[fname])
            x[i * 10000:(i + 1) * 10000, :] = data['data']
            y[i * 10000:(i + 1) * 10000, 0] = data['labels']
            nloaded += 10000
            if nloaded >= ntrain + nvalid + ntest:
                break

        # load test data
        _logger.info('loading file %s' % datasets['test_batch'])
        data = serial.load(datasets['test_batch'])

        # process this data
        Xs = {'train': x[0:ntrain],
              'test': data['data'][0:ntest]}

        Ys = {'train': y[0:ntrain],
              'test': data['labels'][0:ntest]}

        X = numpy.cast['float32'](Xs[which_set])

        y = Ys[which_set]
        if isinstance(y, list):
            y = numpy.asarray(y).astype(dtype)
        if which_set == 'test':
            assert y.shape[0] == 10000
            y = y.reshape((y.shape[0], 1))

        formatter = OneHotFormatter(self.n_classes, dtype=onehot_dtype)
        y = formatter.format(y, mode='concatenate')

        if center:
            X -= 127.5
        self.center = center

        if rescale:
            X /= 127.5
        self.rescale = rescale

        if toronto_prepro:
            assert not center
            assert not gcn
            X = X / 255.
            if which_set == 'test':
                other = CIFAR10(which_set='train')
                oX = other.X
                oX /= 255.
                X = X - oX.mean(axis=0)
            else:
#.........这里部分代码省略.........
开发者ID:HyoungWooPark,项目名称:adversarial,代码行数:101,代码来源:datasets.py


示例18: OneHotFormatter

import theano
import numpy as np

# Rebuild the dataset from the YAML source stored on the model.
# NOTE(review): `model` and `yaml_parse` are not defined in this excerpt;
# they are presumably set up earlier in the script -- confirm.
dataset = yaml_parse.load(model.dataset_yaml_src)

grid_shape = None

# Number of choices for one-hot values
rows = model.generator.condition_space.get_total_dimension()

# Samples per condition
sample_cols = 5

# Generate conditional information
conditional_batch = model.generator.condition_space.make_theano_batch()
formatter = OneHotFormatter(rows,
                            dtype=model.generator.condition_space.dtype)
conditional = formatter.theano_expr(conditional_batch, mode='concatenate')

# Now sample from generator
# For some reason format_as from VectorSpace is not working right
topo_samples_batch = model.generator.sample(conditional)
topo_sample_f = theano.function([conditional], topo_samples_batch)
# One condition index per sample: each of the `rows` indices is repeated
# `sample_cols` times and laid out as a column vector before formatting.
conditional_data = formatter.format(np.concatenate([np.repeat(i, sample_cols) for i in range(rows)])
                                      .reshape((rows * sample_cols, 1)),
                                    mode='concatenate')
topo_samples = topo_sample_f(conditional_data)

# Convert the samples to a design matrix, switch the dataset and its view
# converter to ('b', 0, 1, 'c') axes, then convert back to a topological view.
samples = dataset.get_design_matrix(topo_samples)
dataset.axes = ['b', 0, 1, 'c']
dataset.view_converter.axes = ['b', 0, 1, 'c']
topo_samples = dataset.get_topological_view(samples)
开发者ID:HyoungWooPark,项目名称:adversarial,代码行数:32,代码来源:show_samples_cifar_conditional.py


示例19: load_data


#.........这里部分代码省略.........
        # Leave-one-out cross-validation - seizure
        n_seizures = seizure_range_idx.shape[0]
        rest_seizure_idx = np.setdiff1d(np.arange(n_seizures), test_seizure_idx)
        perm_rest_seizure_idx = np.random.permutation(rest_seizure_idx)
        train_seizure_idx = perm_rest_seizure_idx
        cv_seizure_idx = perm_rest_seizure_idx

        # Leave-one-out cross-validation - non-seizure
        n_train_segments = int(n_segments * 0.6)
        n_cv_segments = int(n_segments * 0.2)
        non_seizure_segment_idx = np.arange(n_segments)
        perm_non_seizure_segment_idx = np.random.permutation(non_seizure_segment_idx)
        train_sample_segments = perm_non_seizure_segment_idx[:n_train_segments]
        cv_sample_segments = perm_non_seizure_segment_idx[n_train_segments:n_train_segments+n_cv_segments]
        test_sample_segments = perm_non_seizure_segment_idx[n_train_segments+n_cv_segments:]
        train_sample_idx = np.empty(0, dtype=int)
        for s in train_sample_segments:
            train_sample_idx = np.append(train_sample_idx, segment_idx[s])
        cv_sample_idx = np.empty(0, dtype=int)
        for s in cv_sample_segments:
            cv_sample_idx = np.append(cv_sample_idx, segment_idx[s])
        test_sample_idx = np.empty(0, dtype=int)
        for s in test_sample_segments:
            test_sample_idx = np.append(test_sample_idx, segment_idx[s])

        print 'Segment index for train, cv and test sets:', \
              train_sample_segments, cv_sample_segments, test_sample_segments

        print 'Seizure index for train, cv and test sets:', \
              train_seizure_idx, cv_seizure_idx, [test_seizure_idx]

        if which_set == 'train':
            print("Loading training data...")
            data = raw_data[:,non_seizure_round_sample_idx[train_sample_idx]]
            labels = raw_labels[non_seizure_round_sample_idx[train_sample_idx]]
            select_seizure = train_seizure_idx
        elif which_set == 'valid':
            print("Loading validation data...")
            data = raw_data[:,non_seizure_round_sample_idx[cv_sample_idx]]
            labels = raw_labels[non_seizure_round_sample_idx[cv_sample_idx]]
            select_seizure = cv_seizure_idx
        elif which_set == 'test':
            print("Loading test data...")
            data = raw_data[:,non_seizure_round_sample_idx[test_sample_idx]]
            labels = raw_labels[non_seizure_round_sample_idx[test_sample_idx]]
            select_seizure = [test_seizure_idx]
        elif which_set == 'all':
            print("Loading all data...")
            data = raw_data
            labels = raw_labels
            select_seizure = []
        else:
            raise('Invalid set.')

        # Add seizure data
        for sz in select_seizure:
            data = np.concatenate((data, raw_data[:, seizure_round_sample_idx[sz]]), axis=1)
            labels = np.concatenate((labels, raw_labels[seizure_round_sample_idx[sz]]), axis=1)

        # No filtering

        # Preprocessing
        if which_set == 'train':
            scaler = preprocessing.StandardScaler()
            scaler = scaler.fit(data.transpose())

            with open(scaler_path, 'w') as f:
                pickle.dump(scaler, f)

            data = scaler.transform(data.transpose()).transpose()
        else:
            with open(scaler_path) as f:
                scaler = pickle.load(f)

            data = scaler.transform(data.transpose()).transpose()

        # Input transformation
        X = np.reshape(data, (-1, sample_size))
        y = np.reshape(labels, (-1, sample_size))
        y = np.sum(y, 1).transpose()
        y[y > 0] = 1

        print 'Seizure index after transform:', np.where(y)[0]
        self.seizure_seconds = np.where(y)[0]

        # Duplicate the labels for all channels
        y = np.tile(y, n_channels)

        # Format the target into proper format
        n_classes = 2
        one_hot_formatter = OneHotFormatter(max_labels=n_classes)
        y = one_hot_formatter.format(y)

        # Check batch size
        cut_off = X.shape[0] % batch_size
        if cut_off > 0:
            X = X[:-cut_off,:]
            y = y[:-cut_off,:]

        return X, y, n_channels, sample_size
开发者ID:akaraspt,项目名称:epilepsy-system,代码行数:101,代码来源:epilepsiae.py


示例20: load_data


#.........这里部分代码省略.........
                        print ' sample_idx:', good_idx[nan_sample_idx], ' feature_idx:', nan_feature_idx
                        print ' shape before remove NaN:', temp_X.shape
                        tmp_preictal_idx = np.where(temp_y_withheld == 1)[0]
                        tmp_nonictal_idx = np.where(temp_y_withheld == 0)[0]
                        nan_preictal_sample_idx = np.intersect1d(tmp_preictal_idx, nan_sample_idx)
                        nan_nonictal_sample_idx = np.intersect1d(tmp_nonictal_idx, nan_sample_idx)
                        if nan_preictal_sample_idx.size > 0:
                            print ' NaN are in preictal index:', good_idx[nan_preictal_sample_idx]
                        if nan_nonictal_sample_idx.size > 0:
                            print ' NaN are in nonictal index:', good_idx[nan_nonictal_sample_idx]
                        all_idx = np.arange(temp_X.shape[1])
                        good_idx_1 = np.setdiff1d(all_idx, nan_sample_idx)
                        temp_X = temp_X[:, good_idx_1]
                        temp_y_all = temp_y_all[good_idx_1]
                        temp_y_withheld = temp_y_withheld[good_idx_1]
                        temp_ictal_labels = temp_ictal_labels[good_idx_1]
                        print ' shape before remove NaN:', temp_X.shape
                        self.nan_non_flat_samples = self.nan_non_flat_samples + nan_sample_idx.size

                    # Sanity check
                    tmp_nan_sample_idx = np.where(np.isnan(np.sum(temp_X, 0)))[0]
                    if tmp_nan_sample_idx.size > 0:
                        raise Exception('There is an error in removing NaN')
                    if not (temp_X.shape[1] == temp_y_all.size):
                        raise Exception('Number of feature data and labels [temp_y_all] are not equal.')
                    if not (temp_X.shape[1] == temp_y_withheld.size):
                        raise Exception('Number of feature data and labels [temp_y_withheld] are not equal.')
                    if not (temp_X.shape[1] == temp_ictal_labels.size):
                        raise Exception('Number of feature data and labels [ictal_labels] are not equal.')

                    if not (X is None) and not (y is None) and not (ictal_labels is None):
                        X = np.concatenate((X, temp_X), axis=1)
                        y = np.append(y, temp_y_withheld)
                        y_label_all = np.append(y_label_all, temp_y_all)
                        ictal_labels = np.append(ictal_labels, temp_ictal_labels)
                    else:
                        X = temp_X
                        y = temp_y_withheld
                        y_label_all = temp_y_all
                        ictal_labels = temp_ictal_labels
                else:
                    print 'There is no good segment for during this seizure'

            # Store preictal labels that are from the withheld index (use for compute accuracy), selected seizure index,
            #  and removed seizure index.
            # Note: this property will exist when which_set=='valid' or which_set=='test'
            #       as there is no need for ictal to be imported.
            self.y_label_all = y_label_all

            # Sanity check
            if np.where(y == 1)[0].size > np.where(y_label_all > 0)[0].size:
                raise Exception('There is an error in collecting preictal labels only from the leave-out-seizure index.')
            if np.where(y == 1)[0].size == np.where(y_label_all == 1)[0].size:
                print 'There is only one preictal periods, and this period is from the leave-out-seizure index.'
                if not np.all(np.where(y == 1)[0] == np.where(y_label_all == 1)[0]):
                    raise Exception('There is a mismatch between y and y_label_all.')
            if np.where(y == 1)[0].size < np.where(y_label_all > 0)[0].size:
                print 'There are more than one preictal periods.'
                if not np.all(np.where(y == 1)[0] == np.where(y_label_all == 1)[0]):
                    raise Exception('There is a mismatch between y_select_idx and y in the preictal labels of the leave-out-seizure index.')

            # Store ictal labels
            # Note: this property will exist when which_set=='valid' or which_set=='test'
            #       as there is no need for ictal to be imported.
            self.ictal_labels = ictal_labels
        else:
            raise Exception('Invalid dataset selection')

        print 'There are {0} samples that have been removed in addition to the flat signal as due to NaN.'.format(self.nan_non_flat_samples)

        X = np.transpose(X, [1, 0])
        one_hot_formatter = OneHotFormatter(max_labels=2)
        y = one_hot_formatter.format(y)

        # Sanity check
        # Note: We ignore nan_non_flat_samples if we load shuffle data as we specify the labels after the NaN have been removed
        #       In contrast to loading continuous data, we specify the labels before removing NaN, so we have to remove the NaN samples for checking
        if self.which_set == 'train' or self.which_set == 'valid_train':
            if not (X.shape[0] == self.preictal_samples + self.nonictal_samples):
                raise Exception('There is a mismatch in the number of training samples ({0} != {1}).'.format(X.shape[0],
                                                                                                             self.preictal_samples + self.nonictal_samples))
            if not (np.where(np.argmax(y, axis=1) == 1)[0].size == self.preictal_samples):
                raise Exception('There is a mismatch in the number of preictal samples and its labels ({0} != {1}).'.format(np.where(np.argmax(y, axis=1) == 1)[0].size,
                                                                                                                            self.preictal_samples))
            if not (X.shape[0] == y.shape[0]):
                raise Exception('There is a mismatch in the number of training samples and its labels ({0} != {1}).'.format(X.shape[0],
                                                                                                                            y.shape[0]))
        elif self.which_set == 'valid' or self.which_set == 'test':
            if not (X.shape[0] == self.preictal_samples + self.nonictal_samples - self.nan_non_flat_samples):
                raise Exception('There is a mismatch in the number of training samples ({0} != {1}).'.format(X.shape[0],
                                                                                                             self.preictal_samples + self.nonictal_samples - self.nan_non_flat_samples))
            if not ((np.where(np.argmax(y, axis=1) == 1)[0].size + np.where(np.argmax(y, axis=1) == 0)[0].size) ==
                        self.preictal_samples + self.nonictal_samples - self.nan_non_flat_samples):
                raise Exception('There is a mismatch in the number of samples and its labels ({0} != {1}).'.format(np.where(np.argmax(y, axis=1) == 1)[0].size + np.where(np.argmax(y, axis=1) == 0)[0].size,
                                                                                                                   self.preictal_samples))
            if not (X.shape[0] == y.shape[0]):
                raise Exception('There is a mismatch in the number of training samples and its labels ({0} != {1}).'.format(X.shape[0],
                                                                                                                            y.shape[0]))

        return X, y
开发者ID:akaraspt,项目名称:epilepsy-system,代码行数:101,代码来源:epilepsiae.py



注:本文中的pylearn2.format.target_format.OneHotFormatter类示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python patch_viewer.make_viewer函数代码示例发布时间:2022-05-25
下一篇:
Python pep8.StyleGuide类代码示例发布时间:2022-05-25
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap