• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Java DataSetIterator类代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Java中org.deeplearning4j.datasets.iterator.DataSetIterator的典型用法代码示例。如果您正苦于以下问题:Java DataSetIterator类的具体用法?Java DataSetIterator怎么用?Java DataSetIterator使用的例子?那么恭喜您, 这里精选的类代码示例或许可以为您提供帮助。



DataSetIterator类属于org.deeplearning4j.datasets.iterator包,在下文中一共展示了DataSetIterator类的8个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Java代码示例。

示例1: evaluate

import org.deeplearning4j.datasets.iterator.DataSetIterator; //导入依赖的package包/类
@Override
   @SuppressWarnings("rawtypes")
   public Model evaluate()
   {
       // Evaluates the trained model against the standard MNIST test split
       // (10,000 examples, read in batches of 100) and prints the
       // precision/recall/accuracy statistics to stdout. Returns this model
       // to allow fluent chaining.
       final Evaluation evaluation = new Evaluation(parameters.getOutputSize());
       try
       {
           final DataSetIterator iterator = new MnistDataSetIterator(100, 10000);
           while (iterator.hasNext())
           {
               final DataSet testingData = iterator.next();
               evaluation.eval(testingData.getLabels(), model.output(testingData.getFeatureMatrix()));
           }
           System.out.println(evaluation.stats());
       }
       catch (IOException e)
       {
           // BUG FIX: previously this printed the stack trace and returned as
           // if evaluation had succeeded. A missing/corrupt MNIST download
           // must not be silently swallowed — rethrow with the cause intact.
           throw new java.io.UncheckedIOException("Unable to load MNIST test data", e);
       }
       return this;
   }
 
开发者ID:amrabed,项目名称:DL4J,代码行数:23,代码来源:StackedAutoEncoderModel.java


示例2: train

import org.deeplearning4j.datasets.iterator.DataSetIterator; //导入依赖的package包/类
@Override
   public Model train()
   {
       // Unsupervised training pass: each batch is fed to the network with
       // its feature matrix used as both the input and the reconstruction
       // target (autoencoder-style fitting). Returns this model for chaining.
       final DataSetIterator batches = data.getIterator();
       while (batches.hasNext())
       {
           final DataSet batch = batches.next();
           model.fit(new DataSet(batch.getFeatureMatrix(), batch.getFeatureMatrix()));
       }
       return this;
   }
 
开发者ID:amrabed,项目名称:DL4J,代码行数:12,代码来源:DeepAutoEncoderModel.java


示例3: main

import org.deeplearning4j.datasets.iterator.DataSetIterator; //导入依赖的package包/类
public static void main(String[] args) throws Exception {
		// MNIST geometry and training hyper-parameters.
		final int numRows = 28;
		final int numColumns = 28;
		final int outputNum = 10;            // ten digit classes
		final int numSamples = 60000;        // full MNIST training set
		final int batchSize = 100;
		final int iterations = 10;
		final int seed = 123;                // fixed seed for reproducibility
		final int listenerFreq = batchSize / 5;

		log.info("Load data....");
		final DataSetIterator trainIter =
				new MnistDataSetIterator(batchSize, numSamples, true);

		log.info("Build model....");
		// Softmax regression is used here; this class also provides
		// deepBeliefNetwork(...) and deepConvNetwork(...) alternatives.
		final MultiLayerNetwork network =
				softMaxRegression(seed, iterations, numRows, numColumns, outputNum);
		network.init();
		network.setListeners(Collections.singletonList(
				(IterationListener) new ScoreIterationListener(listenerFreq)));

		log.info("Train model....");
		network.fit(trainIter); // achieves end to end pre-training

		log.info("Evaluate model....");
		final Evaluation eval = new Evaluation(outputNum);
		final DataSetIterator testIter = new MnistDataSetIterator(100, 10000);
		while (testIter.hasNext()) {
			final DataSet sample = testIter.next();
			eval.eval(sample.getLabels(), network.output(sample.getFeatureMatrix()));
		}

		log.info(eval.stats());
		log.info("****************Example finished********************");

	}
 
开发者ID:PacktPublishing,项目名称:Machine-Learning-End-to-Endguide-for-Java-developers,代码行数:44,代码来源:NeuralNetworks.java


示例4: main

import org.deeplearning4j.datasets.iterator.DataSetIterator; //导入依赖的package包/类
public static void main(String[] args) throws Exception {
      // Hyper-parameters for a stacked auto-encoder trained on MNIST.
      int outputNum = 10;       // ten digit classes
      // NOTE(review): MNIST images are 28x28 = 784 pixels; confirm that an
      // input width of 1000 is intended for this data pipeline.
      int inputNum = 1000;
      int numSamples = 60000;   // full MNIST training set
      int batchSize = 1024;
      int iterations = 10;
      int seed = 123;           // fixed seed for reproducible initialization
      int listenerFreq = batchSize / 5;

      log.info("Load data....");
      DataSetIterator iter = new MnistDataSetIterator(batchSize,numSamples,true);

      log.info("Build model....");
      // Encoder stack 1000 -> 500 -> 250 -> 125 -> 50 feeding a softmax
      // output layer. Layer-wise pre-training only; backprop is disabled.
      MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
         .seed(seed)
         .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
         .gradientNormalizationThreshold(1.0)
         .iterations(iterations)
         .momentum(0.5)
         .momentumAfter(Collections.singletonMap(3, 0.9))
         .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
         .list()
         .layer(0, new AutoEncoder.Builder()
                 .nIn(inputNum)
                 .nOut(500)
                 .weightInit(WeightInit.XAVIER).lossFunction(LossFunction.RMSE_XENT)
                 .corruptionLevel(0.3)
                 .build())
         .layer(1, new AutoEncoder.Builder()
                 .nIn(500)
                 .nOut(250)
                 .weightInit(WeightInit.XAVIER).lossFunction(LossFunction.RMSE_XENT)
                 .corruptionLevel(0.3)
                 .build())
         .layer(2, new AutoEncoder.Builder()
                 .nIn(250)
                 .nOut(125)
                 .weightInit(WeightInit.XAVIER).lossFunction(LossFunction.RMSE_XENT)
                 .corruptionLevel(0.3)
                 .build())
         .layer(3, new AutoEncoder.Builder()
                 .nIn(125)
                 .nOut(50)
                 .weightInit(WeightInit.XAVIER).lossFunction(LossFunction.RMSE_XENT)
                 .corruptionLevel(0.3)
                 .build())
         .layer(4, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD)
                 .activation("softmax")
                 // BUG FIX: layer 3 emits 50 units, so the output layer must
                 // accept 50 inputs. The original declared nIn(75), a
                 // layer-size mismatch between layers 3 and 4.
                 .nIn(50)
                 .nOut(outputNum)
                 .build())
         .pretrain(true)
         .backprop(false)
         .build();

      MultiLayerNetwork model = new MultiLayerNetwork(conf);
      model.init();
      model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(listenerFreq)));

      log.info("Train model....");
      model.fit(iter);

      log.info("Evaluate model....");
      Evaluation eval = new Evaluation(outputNum);

      // Score the trained network on the 10,000-example MNIST test split.
      DataSetIterator testIter = new MnistDataSetIterator(100,10000);
      while(testIter.hasNext()) {
          DataSet testMnist = testIter.next();
          INDArray predict2 = model.output(testMnist.getFeatureMatrix());
          eval.eval(testMnist.getLabels(), predict2);
      }

      log.info(eval.stats());

  }
 
开发者ID:PacktPublishing,项目名称:Deep-Learning-with-Hadoop,代码行数:78,代码来源:StackedAutoEncoder.java


示例5: main

import org.deeplearning4j.datasets.iterator.DataSetIterator; //导入依赖的package包/类
/**
 * Trains a single RBM layer on the Iris dataset and logs its weights.
 * NOTE(review): outputNum is 10 although Iris has only 3 label classes;
 * here the value is used as the RBM's hidden-unit count — confirm intent.
 */
public static void main(String[] args) throws IOException {

    // Print full ND4J arrays (no truncation) and enable numerical-stability
    // enforcement so the weight dump below is complete and checked.
    Nd4j.MAX_SLICES_TO_PRINT = -1;
    Nd4j.MAX_ELEMENTS_PER_SLICE = -1;
    Nd4j.ENFORCE_NUMERICAL_STABILITY = true;
    // Iris: nIn = numRows * numColumns = 4 features per example.
    final int numRows = 4;
    final int numColumns = 1;
    int outputNum = 10;
    int numSamples = 150;
    int batchSize = 150;
    int iterations = 100;
    int seed = 123;
    int listenerFreq = iterations/2;

    log.info("Load data....");
    DataSetIterator iter = new IrisDataSetIterator(batchSize, numSamples);

    // batchSize == numSamples, so a single next() yields the whole dataset.
    DataSet iris = iter.next();

    // Standardize features to zero mean / unit variance before training.
    iris.normalizeZeroMeanZeroUnitVariance();

    log.info("Build model....");
    // Gaussian-visible / rectified-hidden RBM with L1+L2 regularization,
    // contrastive divergence k=3, ADAGRAD updates, and LBFGS optimization.
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().regularization(true)
            .miniBatch(true)

            .layer(new RBM.Builder().l2(1e-1).l1(1e-3)
                    .nIn(numRows * numColumns)  
                    .nOut(outputNum) 
                    .activation("relu")  
                    .weightInit(WeightInit.RELU)  
                    .lossFunction(LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY).k(3)
                    .hiddenUnit(HiddenUnit.RECTIFIED).visibleUnit(VisibleUnit.GAUSSIAN)
                    .updater(Updater.ADAGRAD).gradientNormalization(GradientNormalization.ClipL2PerLayer)
                    .build())
            .seed(seed)  
            .iterations(iterations)
            .learningRate(1e-3)  
            .optimizationAlgo(OptimizationAlgorithm.LBFGS)
            .build();
    Layer model = LayerFactories.getFactory(conf.getLayer()).create(conf);
    model.setListeners(new ScoreIterationListener(listenerFreq));

    log.info("Evaluate weights....");
    // Dump the layer's weight matrix (pre-training values at this point).
    INDArray w = model.getParam(DefaultParamInitializer.WEIGHT_KEY);
    log.info("Weights: " + w);
    log.info("Scaling the dataset");
    iris.scale();
    log.info("Train model....");
    // 20 manual epochs of unsupervised fitting on the feature matrix only.
    for(int i = 0; i < 20; i++) {
        log.info("Epoch "+i+":");
        model.fit(iris.getFeatureMatrix());
    }

}
 
开发者ID:PacktPublishing,项目名称:Deep-Learning-with-Hadoop,代码行数:55,代码来源:RBM.java


示例6: getIterator

import org.deeplearning4j.datasets.iterator.DataSetIterator; //导入依赖的package包/类
/**
 * Exposes the {@link DataSetIterator} backing this data source.
 *
 * @return the underlying iterator
 */
public DataSetIterator getIterator()
   {
       return this.iterator;
   }
 
开发者ID:amrabed,项目名称:DL4J,代码行数:5,代码来源:Data.java


示例7: main

import org.deeplearning4j.datasets.iterator.DataSetIterator; //导入依赖的package包/类
/**
 * End-to-end demo: learns word vectors with Word2Vec, pre-trains a DBN on
 * the embedding matrix, fine-tunes a semantic-hashing autoencoder over it,
 * and plots a t-SNE visualization of the reconstructed vocabulary vectors.
 */
public static void main(String[] args) throws Exception {
    // Sentence stream built from the bundled train.tsv phrase file.
    SentenceIterator docIter = new CollectionSentenceIterator(new SentenceToPhraseMapper(new ClassPathResource("/train.tsv").getFile()).sentences());
    TokenizerFactory factory = new DefaultTokenizerFactory();
    // 100-dimensional embeddings with a context window of 5 words.
    Word2Vec  vec = new Word2Vec.Builder().iterate(docIter).tokenizerFactory(factory).batchSize(100000)
            .learningRate(2.5e-2).iterations(1)
            .layerSize(100).windowSize(5).build();
    vec.fit();

    // Gaussian-visible / rectified-hidden RBM configuration shared by the
    // DBN layers; input/output widths match the embedding dimension.
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().nIn(vec.getLayerSize()).nOut(vec.getLayerSize())
            .hiddenUnit(RBM.HiddenUnit.RECTIFIED).visibleUnit(RBM.VisibleUnit.GAUSSIAN).momentum(0.5f)
            .iterations(10).learningRate(1e-6f).build();

    InMemoryLookupCache l = (InMemoryLookupCache) vec.getCache();

    // DBN with a 2-unit bottleneck as the final hidden layer.
    DBN d = new DBN.Builder()
            .configure(conf).hiddenLayerSizes(new int[]{250,100,2})
            .build();
    // Autoencoder-style pre-training data: syn0 (the learned word vectors)
    // serves as both input and target, streamed in batches of 1000 rows.
    DataSet dPretrain = new DataSet(l.getSyn0(),l.getSyn0());
    DataSetIterator dPretrainIter =  new ListDataSetIterator(dPretrain.asList(),1000);
    while(dPretrainIter.hasNext()) {
        d.pretrain(dPretrainIter.next().getFeatureMatrix(), 1, 1e-6f, 10);


    }

    // d.pretrain(l.getSyn0(),1,1e-3f,1000);
    d.getOutputLayer().conf().setLossFunction(LossFunctions.LossFunction.RMSE_XENT);

    SemanticHashing s = new SemanticHashing.Builder().withEncoder(d)
            .build();

    // Drop the DBN reference; the SemanticHashing model now owns the encoder.
    d = null;

    // Second pass over the same data to fine-tune the semantic-hashing model.
    dPretrainIter.reset();
    while(dPretrainIter.hasNext()) {
        s.fit(dPretrainIter.next());

    }




    Tsne t = new Tsne.Builder()
            .setMaxIter(100).stopLyingIteration(20).build();

    // Reconstruct the vocabulary matrix through the trained model.
    // NOTE(review): the second argument (4) is presumably a layer index /
    // depth parameter of reconstruct — verify against the API.
    INDArray output = s.reconstruct(l.getSyn0(),4);
    // Flush large embedding buffers and release the model before plotting.
    l.getSyn0().data().flush();
    l.getSyn1().data().flush();
    s = null;
    System.out.println(Arrays.toString(output.shape()));
    t.plot(output,2,new ArrayList<>(vec.getCache().words()));
    vec.getCache().plotVocab(t);

}
 
开发者ID:ihuerga,项目名称:deeplearning4j-nlp-examples,代码行数:55,代码来源:VisualizationSemanticHashing.java


示例8: main

import org.deeplearning4j.datasets.iterator.DataSetIterator; //导入依赖的package包/类
/**
 * Trains a DBN on the Rotten Tomatoes dataset via word-vector batches and
 * the actor-based network runner.
 */
public static void main(String[] args) {
    // Single epoch over 800-example batches of Rotten Tomatoes word vectors.
    DataSetIterator iter = new MultipleEpochsIterator(1,new RottenTomatoesTrainDataSetIterator(800,800,new RottenTomatoesWordVectorDataFetcher()));
    // Gaussian-visible / rectified-hidden RBM layer template with hard-tanh
    // activation and L2 regularization; 5 output units — presumably the five
    // Rotten Tomatoes sentiment labels (verify against the dataset).
    NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().nIn(iter.inputColumns()).hiddenUnit(RBM.HiddenUnit.RECTIFIED)
            .visibleUnit(RBM.VisibleUnit.GAUSSIAN).iterations(10).momentum(0.5f)
            .nOut(5).activationFunction(Activations.hardTanh()).learningRate(1e-6f).regularization(true)
            .l2(2e-4f).build();
    ActorNetworkRunner runner = new ActorNetworkRunner(iter);


    // Runner configuration: setConf must precede setLayerConfigs, which
    // derives the per-layer configs mutated below.
    Conf c = new Conf();
    c.setConf(conf);
    c.setMultiLayerClazz(DBN.class);
    c.setLayerConfigs();
    c.setSplit(100);
    // Output layer: row-wise softmax with multi-class cross-entropy loss.
    c.getLayerConfigs().get(c.getLayerConfigs().size() - 1).setActivationFunction(Activations.softMaxRows());
    c.getLayerConfigs().get(c.getLayerConfigs().size() - 1).setLossFunction(LossFunctions.LossFunction.MCXENT);

    // Two hidden layers of 1000 and 500 units.
    c.setLayerSizes(new int[]{1000,500});

    runner.setup(c);
    runner.train();

}
 
开发者ID:ihuerga,项目名称:deeplearning4j-nlp-examples,代码行数:30,代码来源:MovingWindowDBN.java



注:本文中的org.deeplearning4j.datasets.iterator.DataSetIterator类示例整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Java UiChild类代码示例发布时间:2022-05-22
下一篇:
Java FlowEvent类代码示例发布时间:2022-05-22
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2023 极客世界.|Sitemap