
Java Model Class Code Examples


This article collects typical usage examples of org.deeplearning4j.nn.api.Model in Java. If you want to know what the Java Model class does, how to use it, or what real-world usage looks like, the selected examples below may help.



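The Model class belongs to the org.deeplearning4j.nn.api package. Twenty examples of the Model class are shown below, sorted by popularity by default. In most of them, the Model that is passed in is a MultiLayerNetwork, a ComputationGraph or a single Layer, and the code branches or casts accordingly.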

Example 1: fromFile

import org.deeplearning4j.nn.api.Model; // import the required package/class
public static DLModel fromFile(File file) throws Exception {
	Model model = null;
	try {
		System.out.println("Trying to load file as computation graph: " + file);
		model = ModelSerializer.restoreComputationGraph(file);
		System.out.println("Loaded Computation Graph.");
	} catch (Exception e) {
		try {
			System.out.println("Failed to load computation graph. Trying to load model.");
			model = ModelSerializer.restoreMultiLayerNetwork(file);
			System.out.println("Loaded Multilayernetwork");
		} catch (Exception e1) {
			System.out.println("Give up trying to load file: " + file);
			e.addSuppressed(e1); // keep the MultiLayerNetwork failure as well
			throw e;
		}
	}
	return new DLModel(model);
}
 
Developer ID: jesuino, Project: java-ml-projects, Lines of code: 19, Source: DLModel.java
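
Example 1 tries `ModelSerializer.restoreComputationGraph` first and falls back to `restoreMultiLayerNetwork`. The same try-in-order idea can be written as a small generic helper. The sketch below is plain Java with no DL4J dependency; `FallbackLoader` and `loadFirst` are hypothetical names, and each `Callable` stands in for one restore call. Like Example 1 it rethrows the first failure, but it also attaches later failures as suppressed exceptions so no diagnostic detail is lost.

```java
import java.util.concurrent.Callable;

// Hypothetical helper (not part of DL4J): try each loader in order and
// return the first successful result.
final class FallbackLoader {
    private FallbackLoader() {}

    @SafeVarargs
    static <T> T loadFirst(Callable<? extends T>... loaders) throws Exception {
        if (loaders.length == 0) {
            throw new IllegalArgumentException("no loaders given");
        }
        Exception first = null;
        for (Callable<? extends T> loader : loaders) {
            try {
                return loader.call();
            } catch (Exception e) {
                if (first == null) {
                    first = e;              // remember the first failure, as Example 1 rethrows it
                } else {
                    first.addSuppressed(e); // keep later failures attached instead of dropping them
                }
            }
        }
        // Every loader failed; first is non-null because at least one loader ran.
        throw first;
    }
}
```

With DL4J on the classpath, the two loaders could be `() -> ModelSerializer.restoreComputationGraph(file)` and `() -> ModelSerializer.restoreMultiLayerNetwork(file)`, with `Model` as the common result type.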


Example 2: onEpochEnd

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
public void onEpochEnd(Model model) {
  currentEpoch++;

  // Skip if this is not an evaluation epoch
  if (currentEpoch % n != 0) {
    return;
  }

  String s = "Epoch [" + currentEpoch + "/" + numEpochs + "]\n";

  if (enableIntermediateEvaluations) {
    s += "Train Set:      \n" + evaluateDataSetIterator(model, trainIterator, true);
    if (validationIterator != null) {
      s += "Validation Set: \n" + evaluateDataSetIterator(model, validationIterator, false);
    }
  }

  log(s);
}
 
Developer ID: Waikato, Project: wekaDeeplearning4j, Lines of code: 21, Source: EpochListener.java


Example 3: iterationDone

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
public void iterationDone(Model model, int i) {
    if (printIterations <= 0)
        printIterations = 1;
    if (iterCount % printIterations == 0) {
        iter.reset();
        double cost = 0;
        double count = 0;
        while(iter.hasNext()) {
            DataSet minibatch = iter.next(miniBatchSize);
            cost += ((MultiLayerNetwork)model).scoreExamples(minibatch, false).sumNumber().doubleValue();
            count += minibatch.getLabelsMaskArray().sumNumber().doubleValue();
        }
        log.info(String.format("Iteration %5d test set score: %.4f", iterCount, cost/count));
    }
    iterCount++;
}
 
Developer ID: jpatanooga, Project: strata-2016-nyc-dl4j-rnn, Lines of code: 18, Source: HeldoutScoreIterationListener.java


Example 4: iterationDone

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
public void iterationDone(Model model, int iteration, int epoch) {
    //Check per-iteration termination conditions
    double latestScore = model.score();
    trainer.setLatestScore(latestScore);
    for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) {
        if (c.terminate(latestScore)) {
            trainer.setTermination(true);
            trainer.setTerminationReason(c);
            break;
        }
    }
    if (trainer.getTermination()) {
        // use built-in kill switch to stop fit operation
        wrapper.stopFit();
    }

    trainer.incrementIteration();
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 20, Source: EarlyStoppingParallelTrainer.java
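
Example 4 checks every per-iteration termination condition against the latest score and stops at the first one that fires, recording it as the termination reason. The sketch below restates that loop in plain Java. `ScoreCondition` and the two sample conditions (an invalid-score check and a maximum-score check) are hypothetical stand-ins, not the DL4J `IterationTerminationCondition` implementations.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.DoublePredicate;

// Hypothetical stand-in for a per-iteration termination condition:
// a name (used as the "termination reason") plus a test on the latest score.
final class ScoreCondition {
    final String name;
    final DoublePredicate terminate;

    ScoreCondition(String name, DoublePredicate terminate) {
        this.name = name;
        this.terminate = terminate;
    }

    /** Returns the first condition that fires for this score, mirroring the loop-and-break in Example 4. */
    static Optional<ScoreCondition> firstTriggered(List<ScoreCondition> conditions, double latestScore) {
        for (ScoreCondition c : conditions) {
            if (c.terminate.test(latestScore)) {
                return Optional.of(c);
            }
        }
        return Optional.empty();
    }

    /** Two illustrative conditions; order matters because the first match wins. */
    static List<ScoreCondition> defaults(double maxScore) {
        return Arrays.asList(
                new ScoreCondition("invalid score", s -> Double.isNaN(s) || Double.isInfinite(s)),
                new ScoreCondition("score above " + maxScore, s -> s > maxScore));
    }
}
```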


Example 5: testListenersForModel

import org.deeplearning4j.nn.api.Model; // import the required package/class
private static void testListenersForModel(Model model, List<IterationListener> listeners) {
    int nWorkers = 2;
    ParallelWrapper wrapper = new ParallelWrapper.Builder(model).workers(nWorkers).averagingFrequency(1)
                    .reportScoreAfterAveraging(true).build();

    if (listeners != null) {
        wrapper.setListeners(listeners);
    }

    List<DataSet> data = new ArrayList<>();
    for (int i = 0; i < nWorkers; i++) {
        data.add(new DataSet(Nd4j.rand(1, 10), Nd4j.rand(1, 10)));
    }

    DataSetIterator iter = new ExistingDataSetIterator(data);

    TestListener.clearCounts();
    wrapper.fit(iter);

    assertEquals(2, TestListener.workerIDs.size());
    assertEquals(1, TestListener.sessionIDs.size());
    assertEquals(2, TestListener.forwardPassCount.get());
    assertEquals(2, TestListener.backwardPassCount.get());
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 26, Source: TestListeners.java


Example 6: updateGradientAccordingToParams

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize) {
    if (model instanceof ComputationGraph) {
        ComputationGraph graph = (ComputationGraph) model;
        if (computationGraphUpdater == null) {
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                computationGraphUpdater = new ComputationGraphUpdater(graph);
            }
        }
        computationGraphUpdater.update(gradient, getIterationCount(model), getEpochCount(model), batchSize);
    } else {
        if (updater == null) {
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                updater = UpdaterCreator.getUpdater(model);
            }
        }
        Layer layer = (Layer) model;

        updater.update(layer, gradient, getIterationCount(model), getEpochCount(model), batchSize);
    }
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 22, Source: BaseOptimizer.java


Example 7: onForwardPass

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
    int iterCount = getModelInfo(model).iterCount;
    if (calcFromActivations() && updateConfig.reportingFrequency() > 0
            && (iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0)) {
        if (updateConfig.collectHistograms(StatsType.Activations)) {
            activationHistograms = getHistograms(activations, updateConfig.numHistogramBins(StatsType.Activations));
        }
        if (updateConfig.collectMean(StatsType.Activations)) {
            meanActivations = calculateSummaryStats(activations, StatType.Mean);
        }
        if (updateConfig.collectStdev(StatsType.Activations)) {
            stdevActivations = calculateSummaryStats(activations, StatType.Stdev);
        }
        if (updateConfig.collectMeanMagnitudes(StatsType.Activations)) {
            meanMagActivations = calculateSummaryStats(activations, StatType.MeanMagnitude);
        }
    }
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 20, Source: BaseStatsListener.java
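
Example 7 only collects activation statistics when the reporting frequency is positive and the iteration count is a multiple of it (or zero). A small sketch of that gate, using a hypothetical `ReportingGate` class, also shows that the explicit `iterCount == 0` test is redundant once the frequency is known to be positive, since `0 % f == 0` for any `f > 0`.

```java
// Hypothetical helper restating the reporting check in Example 7.
final class ReportingGate {
    private ReportingGate() {}

    static boolean shouldReport(int iterCount, int reportingFrequency) {
        // A non-positive frequency disables reporting entirely; the
        // "every N iterations" listeners in Examples 3, 13, 15 and 16
        // instead clamp it to 1 and report every iteration.
        if (reportingFrequency <= 0) {
            return false;
        }
        // iterCount == 0 is already covered: 0 % f == 0 for any f > 0.
        return iterCount % reportingFrequency == 0;
    }
}
```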


Example 8: onGradientCalculation

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
public void onGradientCalculation(Model model) {
    int iterCount = getModelInfo(model).iterCount;
    if (calcFromGradients() && updateConfig.reportingFrequency() > 0
            && (iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0)) {
        Gradient g = model.gradient();
        if (updateConfig.collectHistograms(StatsType.Gradients)) {
            gradientHistograms = getHistograms(g.gradientForVariable(), updateConfig.numHistogramBins(StatsType.Gradients));
        }

        if (updateConfig.collectMean(StatsType.Gradients)) {
            meanGradients = calculateSummaryStats(g.gradientForVariable(), StatType.Mean);
        }
        if (updateConfig.collectStdev(StatsType.Gradients)) {
            stdevGradient = calculateSummaryStats(g.gradientForVariable(), StatType.Stdev);
        }
        if (updateConfig.collectMeanMagnitudes(StatsType.Gradients)) {
            meanMagGradients = calculateSummaryStats(g.gradientForVariable(), StatType.MeanMagnitude);
        }
    }
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 22, Source: BaseStatsListener.java
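
Examples 7 and 8 ask `calculateSummaryStats` for the mean, standard deviation and mean magnitude of each activation or gradient array. The sketch below computes the same three statistics over plain `double[]` arrays keyed by name. `SummaryStats` and its `Stat` enum are hypothetical, and using the population standard deviation (dividing by n) is an assumption that may not match ND4J's default.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical re-implementation of the three statistics requested in
// Examples 7 and 8, using double[] instead of INDArray.
final class SummaryStats {
    enum Stat { MEAN, STDEV, MEAN_MAGNITUDE }

    private SummaryStats() {}

    static double compute(double[] values, Stat stat) {
        if (values.length == 0) {
            throw new IllegalArgumentException("empty array");
        }
        double sum = 0, sumAbs = 0;
        for (double v : values) {
            sum += v;
            sumAbs += Math.abs(v);
        }
        double mean = sum / values.length;
        switch (stat) {
            case MEAN:
                return mean;
            case MEAN_MAGNITUDE:
                return sumAbs / values.length;   // average of |v|
            case STDEV:
            default:
                // Population standard deviation (divides by n); an assumption,
                // ND4J's own std may apply a bias correction instead.
                double sq = 0;
                for (double v : values) {
                    sq += (v - mean) * (v - mean);
                }
                return Math.sqrt(sq / values.length);
        }
    }

    /** Applies one statistic to every named array, e.g. one entry per parameter. */
    static Map<String, Double> computeAll(Map<String, double[]> arrays, Stat stat) {
        Map<String, Double> out = new LinkedHashMap<>();
        for (Map.Entry<String, double[]> entry : arrays.entrySet()) {
            out.put(entry.getKey(), compute(entry.getValue(), stat));
        }
        return out;
    }
}
```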


Example 9: configureListeners

import org.deeplearning4j.nn.api.Model; // import the required package/class
private void configureListeners(Model m, int counter) {
    if (iterationListeners != null) {
        List<IterationListener> list = new ArrayList<>(iterationListeners.size());
        for (IterationListener l : iterationListeners) {
            if (listenerRouterProvider != null && l instanceof RoutingIterationListener) {
                RoutingIterationListener rl = (RoutingIterationListener) l;
                rl.setStorageRouter(listenerRouterProvider.getRouter());
                String workerID = UIDProvider.getJVMUID() + "_" + counter;
                rl.setWorkerID(workerID);
            }
            list.add(l); //Don't need to clone listeners: not from broadcast, so deserialization handles
        }
        if (m instanceof MultiLayerNetwork)
            ((MultiLayerNetwork) m).setListeners(list);
        else
            ((ComputationGraph) m).setListeners(list);
    }
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 19, Source: ParameterAveragingTrainingWorker.java


Example 10: testLoadNormalizersFile

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Test
public void testLoadNormalizersFile() throws Exception {
    MultiLayerNetwork net = getNetwork();

    File tempFile = File.createTempFile("tsfs", "fdfsdf");
    tempFile.deleteOnExit();

    ModelSerializer.writeModel(net, tempFile, true);

    NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
    ModelSerializer.addNormalizerToModel(tempFile, normalizer);
    Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
    Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath());
    assertEquals(net, model); // JUnit order: expected, then actual
    assertEquals(normalizer, normalizer1);

}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 19, Source: ModelGuesserTest.java


Example 11: testLoadNormalizersInputStream

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Test
public void testLoadNormalizersInputStream() throws Exception {
    MultiLayerNetwork net = getNetwork();

    File tempFile = File.createTempFile("tsfs", "fdfsdf");
    tempFile.deleteOnExit();

    ModelSerializer.writeModel(net, tempFile, true);

    NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
    ModelSerializer.addNormalizerToModel(tempFile, normalizer);
    Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
    try (InputStream inputStream = new FileInputStream(tempFile)) {
        Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(inputStream);
        assertEquals(net, model); // JUnit order: expected, then actual
        assertEquals(normalizer, normalizer1);
    }

}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 21, Source: ModelGuesserTest.java


Example 12: iterationDone

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
public void iterationDone(final Model model, final int iteration) {
    runOnUiThread(new Runnable() {
        @Override
        public void run() {
            if (iteration % 100 == 0) {
                double result = model.score();
                String message = "\nScore at iteration " + iteration + " is " + result;
                Log.d(TAG, message);

                loggingArea.append(message);
            }
        }
    });
}
 
Developer ID: mccorby, Project: FederatedAndroidTrainer, Lines of code: 16, Source: MainActivity.java


Example 13: iterationDone

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
public void iterationDone(Model model, int iteration) {
    if(m_printIterations <= 0)
        m_printIterations = 1;
    if(m_iterCount % m_printIterations == 0) {
        invoke();
        double result = model.score();
        m_progressBar.printProgress("Iteration: " + m_iterCount + ", Score: " + result);
    }
    m_iterCount++;
}
 
Developer ID: braeunlich, Project: anagnostes, Lines of code: 12, Source: TrainProgressIterationListener.java


Example 14: iterationDone

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
public void iterationDone (Model model,
                           int iteration)
{
    iterCount++;

    if ((iterCount % constants.listenerPeriod.getValue()) == 0) {
        invoke();

        final double score = model.score();
        final int count = (int) iterCount;
        logger.info(String.format("Score at iteration %d is %.5f", count, score));
        display(epoch, count, score);
    }
}
 
Developer ID: Audiveris, Project: audiveris, Lines of code: 16, Source: TrainingPanel.java


Example 15: iterationDone

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
public void iterationDone(Model model, int i) {
    if (printIterations <= 0)
        printIterations = 1;
    if (iterCount % printIterations == 0) {
        saveModel((MultiLayerNetwork)model, this.modelSavePath);
    }
    iterCount++;
}
 
Developer ID: jpatanooga, Project: strata-2016-nyc-dl4j-rnn, Lines of code: 10, Source: ModelSaver.java


Example 16: iterationDone

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
public void iterationDone(Model model, int i) {
    if(printIterations <= 0)
        printIterations = 1;
    if (iterCount % printIterations == 0) {
        invoke();
        String[] samples = sampleBeerRatingFromNetwork(net, reader, rng, temperature, maxCharactersToSample, 1, styleIndex);

        System.out.println("----- Generating Lager Beer Review Samples -----");
        for (int j = 0; j < samples.length; j++) {
            System.out.println("SAMPLE " + j + ": " + samples[j]);
        }
    }
    iterCount++;
}
 
Developer ID: jpatanooga, Project: strata-2016-nyc-dl4j-rnn, Lines of code: 16, Source: SampleGeneratorListener.java
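
Examples 3, 13, 15 and 16 all repeat the same throttle: treat a non-positive period as 1, act when `iterCount % printIterations == 0`, then increment the counter, so the work runs on calls 0, N, 2N and so on. A minimal sketch of that logic as a reusable class follows; `EveryNIterations` is a hypothetical name, not a DL4J type.

```java
// Hypothetical helper (not a DL4J class) that factors out the
// "every N iterations" check shared by Examples 3, 13, 15 and 16.
final class EveryNIterations {
    private final int period;
    private long iterCount = 0;

    EveryNIterations(int period) {
        // Same guard as the listeners: a non-positive period means "every iteration".
        this.period = period <= 0 ? 1 : period;
    }

    /** Returns true if the work should run on this call, then advances the counter. */
    boolean tick() {
        boolean due = iterCount % period == 0;
        iterCount++;
        return due;
    }
}
```

Examples 2 and 14 increment their counter before checking, so they first fire on the Nth call rather than on the first one; moving `iterCount++` above the modulo test reproduces that variant.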


Example 17: InferenceWorker

import org.deeplearning4j.nn.api.Model; // import the required package/class
private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice) {
    this.inputQueue = inputQueue;
    this.protoModel = model;
    this.rootDevice = rootDevice;

    this.setDaemon(true);
    this.setName("InferenceThread-" + id);

}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 10, Source: ParallelInference.java
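
Example 17 configures each inference worker as a named daemon thread that reads requests from a shared `BlockingQueue`. The sketch below shows that pattern using only `java.util.concurrent`; `QueueWorker` and its `Consumer` handler are hypothetical, standing in for whatever the real worker does with each request (presumably running the model). Interrupting the thread is assumed to be the shutdown signal.

```java
import java.util.concurrent.BlockingQueue;
import java.util.function.Consumer;

// Hypothetical worker modelled on the InferenceWorker constructor in Example 17:
// a named daemon thread that takes requests from a shared BlockingQueue.
final class QueueWorker<T> extends Thread {
    private final BlockingQueue<T> inputQueue;
    private final Consumer<T> handler;

    QueueWorker(int id, BlockingQueue<T> inputQueue, Consumer<T> handler) {
        this.inputQueue = inputQueue;
        this.handler = handler;
        setDaemon(true);                  // do not keep the JVM alive on shutdown
        setName("InferenceThread-" + id); // same naming scheme as Example 17
    }

    @Override
    public void run() {
        try {
            while (!isInterrupted()) {
                handler.accept(inputQueue.take()); // blocks until a request arrives
            }
        } catch (InterruptedException e) {
            // interrupt() is the shutdown signal; restore the flag and exit
            Thread.currentThread().interrupt();
        }
    }
}
```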


Example 18: scoreMinibatch

import org.deeplearning4j.nn.api.Model; // import the required package/class
@Override
protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) {
    if(network instanceof MultiLayerNetwork){
        return ((MultiLayerNetwork) network).score(new DataSet(get0(features), get0(labels), get0(fMask), get0(lMask)), false)
                * features[0].size(0);
    } else if(network instanceof ComputationGraph){
        return ((ComputationGraph) network).score(new MultiDataSet(features, labels, fMask, lMask))
                * features[0].size(0);
    } else {
        throw new RuntimeException("Unknown model type: " + network.getClass());
    }
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 13, Source: DataSetLossCalculator.java
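
Example 18 multiplies each minibatch score by `features[0].size(0)`, the number of examples in that minibatch. Presumably the caller then sums these products and divides by the total example count, which yields a per-example average that is not skewed by a smaller final minibatch. A hypothetical `WeightedScore` accumulator illustrating that arithmetic:

```java
// Hypothetical accumulator (not a DL4J class): each minibatch contributes
// averageScore * minibatchSize, and the total is divided once at the end.
final class WeightedScore {
    private double sumOfScores = 0;
    private long exampleCount = 0;

    void addMinibatch(double averageScore, long minibatchSize) {
        sumOfScores += averageScore * minibatchSize; // undo the per-minibatch averaging
        exampleCount += minibatchSize;
    }

    double average() {
        if (exampleCount == 0) {
            throw new IllegalStateException("no examples seen");
        }
        return sumOfScores / exampleCount;
    }
}
```

For minibatches with average scores 1.0 (2 examples) and 4.0 (1 example), this gives (2 + 4) / 3 = 2.0, whereas averaging the two minibatch averages would give 2.5.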


Example 19: updateExamplesMinibatchesCounts

import org.deeplearning4j.nn.api.Model; // import the required package/class
private void updateExamplesMinibatchesCounts(Model model) {
    ModelInfo modelInfo = getModelInfo(model);
    int examplesThisMinibatch = 0;
    if (model instanceof MultiLayerNetwork) {
        examplesThisMinibatch = ((MultiLayerNetwork) model).batchSize();
    } else if (model instanceof ComputationGraph) {
        examplesThisMinibatch = ((ComputationGraph) model).batchSize();
    } else if (model instanceof Layer) {
        examplesThisMinibatch = ((Layer) model).getInputMiniBatchSize();
    }
    modelInfo.examplesSinceLastReport += examplesThisMinibatch;
    modelInfo.totalExamples += examplesThisMinibatch;
    modelInfo.minibatchesSinceLastReport++;
    modelInfo.totalMinibatches++;
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 16, Source: BaseStatsListener.java
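
Example 19 keeps two kinds of counters: running totals and "since last report" values. The sketch below, using a hypothetical `ExampleCounters` class, shows how the window counters would be reset when a report is produced while the totals keep growing. Deriving an examples-per-second rate from the window is an illustrative assumption; the excerpt itself does not show how the counters are consumed.

```java
// Hypothetical counterpart to the ModelInfo counters updated in Example 19.
// "Since last report" counters are reset when a report is emitted; totals are not.
final class ExampleCounters {
    long examplesSinceLastReport;
    long totalExamples;
    int minibatchesSinceLastReport;
    int totalMinibatches;

    void onMinibatch(int examplesThisMinibatch) {
        examplesSinceLastReport += examplesThisMinibatch;
        totalExamples += examplesThisMinibatch;
        minibatchesSinceLastReport++;
        totalMinibatches++;
    }

    /** Assumed use: examples per second over the reporting window, then start a new window. */
    double reportAndReset(double secondsSinceLastReport) {
        double rate = secondsSinceLastReport > 0
                ? examplesSinceLastReport / secondsSinceLastReport
                : 0.0;
        examplesSinceLastReport = 0;
        minibatchesSinceLastReport = 0;
        return rate;
    }
}
```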


Example 20: getUpdater

import org.deeplearning4j.nn.api.Model; // import the required package/class
public static org.deeplearning4j.nn.api.Updater getUpdater(Model layer) {
    if (layer instanceof MultiLayerNetwork) {
        return new MultiLayerUpdater((MultiLayerNetwork) layer);
    } else if (layer instanceof ComputationGraph) {
        return new ComputationGraphUpdater((ComputationGraph) layer);
    } else {
        return new LayerUpdater((Layer) layer);
    }
}
 
Developer ID: deeplearning4j, Project: deeplearning4j, Lines of code: 10, Source: UpdaterCreator.java
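
Examples 6, 18, 19 and 20 all branch on the concrete type behind a `Model`. The sketch below isolates that dispatch with hypothetical stand-in types (`SketchModel`, `SketchMultiLayer`, `SketchGraph`, `SketchLayer`) rather than the DL4J classes. One difference is deliberate: Example 20 casts any remaining model straight to `Layer`, whereas this sketch checks the last case too and throws a descriptive exception for unknown types, as Example 18 does.

```java
// Hypothetical stand-ins for the three Model kinds the examples branch on;
// they are NOT the DL4J classes, just enough to show the dispatch.
interface SketchModel {}
final class SketchMultiLayer implements SketchModel {}
final class SketchGraph implements SketchModel {}
final class SketchLayer implements SketchModel {}

final class UpdaterDispatch {
    private UpdaterDispatch() {}

    /** Picks an updater kind by concrete type, failing loudly on anything else. */
    static String updaterFor(SketchModel model) {
        if (model instanceof SketchMultiLayer) {
            return "multi-layer updater";
        } else if (model instanceof SketchGraph) {
            return "computation-graph updater";
        } else if (model instanceof SketchLayer) {
            return "single-layer updater";
        }
        // An explicit error gives a clearer message than a ClassCastException.
        throw new IllegalArgumentException("Unknown model type: " + model.getClass());
    }
}
```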



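Note: The org.deeplearning4j.nn.api.Model examples in this article were collected from source code and documentation platforms such as GitHub and MSDocs. The snippets come from open-source projects contributed by their developers, and copyright remains with the original authors. Refer to each project's license before distributing or using the code, and do not reproduce this article without permission.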

