本文整理汇总了Java中org.deeplearning4j.api.storage.StatsStorage类的典型用法代码示例。如果您正苦于以下问题：Java StatsStorage类的具体用法？Java StatsStorage怎么用？Java StatsStorage使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。
StatsStorage类属于org.deeplearning4j.api.storage包，在下文中一共展示了StatsStorage类的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Java代码示例。
示例1: getDefaultSession
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Selects a default session when none is selected yet.
 *
 * <p>For each known session whose storage returns a non-empty list of static infos for
 * {@code StatsListener.TYPE_ID}, the timestamp of the first element of that list is compared;
 * the session with the largest such timestamp becomes {@code currentSessionID}. If no session
 * qualifies, {@code currentSessionID} is left unchanged (null).
 */
private void getDefaultSession() {
    if (currentSessionID != null) {
        return; // a session is already selected
    }

    String newestSession = null;
    long newestTimestamp = Long.MIN_VALUE;
    for (Map.Entry<String, StatsStorage> session : knownSessionIDs.entrySet()) {
        String candidateId = session.getKey();
        List<Persistable> infos = session.getValue().getAllStaticInfos(candidateId, StatsListener.TYPE_ID);
        boolean hasInfo = infos != null && !infos.isEmpty();
        if (!hasInfo) {
            continue;
        }
        long candidateTime = infos.get(0).getTimeStamp();
        if (candidateTime > newestTimestamp) {
            newestTimestamp = candidateTime;
            newestSession = candidateId;
        }
    }

    if (newestSession != null) {
        currentSessionID = newestSession;
    }
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:23,代码来源:TrainModule.java
示例2: getModelGraph
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Returns the model-graph description of the current session as JSON.
 *
 * <p>Responds with an empty {@code ok()} when no session is selected, when the selected
 * session is no longer present in {@code knownSessionIDs} (entries are removed when a storage
 * is detached), when the session has no static info for {@code StatsListener.TYPE_ID}, or when
 * {@code getGraphInfo()} returns null.
 */
private Result getModelGraph() {
    // Read the field once so the null check and the lookups below all refer to the same session.
    String sid = currentSessionID;
    StatsStorage ss = (sid == null ? null : knownSessionIDs.get(sid));
    if (ss == null) {
        // No session selected, or its storage has been detached: previously this dereferenced null.
        return ok();
    }
    List<Persistable> allStatic = ss.getAllStaticInfos(sid, StatsListener.TYPE_ID);
    if (allStatic == null || allStatic.isEmpty()) {
        return ok();
    }
    TrainModuleUtils.GraphInfo gi = getGraphInfo();
    if (gi == null) {
        return ok();
    }
    return ok(Json.toJson(gi));
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:TrainModule.java
示例3: attach
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Attaches a {@link StatsStorage} to this UI server.
 *
 * <p>A {@code QueueStatsStorageListener} feeding {@code eventQueue} is registered on the
 * storage and remembered in {@code listeners}, the storage is added to
 * {@code statsStorageInstances}, and every UI module is notified through {@code onAttach}.
 * Attaching an instance that is already in {@code statsStorageInstances} does nothing.
 *
 * @param statsStorage storage to attach; must not be null
 * @throws IllegalArgumentException if {@code statsStorage} is null
 */
@Override
public synchronized void attach(StatsStorage statsStorage) {
    if (statsStorage == null) {
        throw new IllegalArgumentException("StatsStorage cannot be null");
    }
    boolean alreadyAttached = statsStorageInstances.contains(statsStorage);
    if (alreadyAttached) {
        return;
    }

    StatsStorageListener queueListener = new QueueStatsStorageListener(eventQueue);
    listeners.add(new Pair<>(statsStorage, queueListener));
    statsStorage.registerStatsStorageListener(queueListener);
    statsStorageInstances.add(statsStorage);

    for (UIModule module : uiModules) {
        module.onAttach(statsStorage);
    }
    log.info("StatsStorage instance attached to UI: {}", statsStorage);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:18,代码来源:PlayUIServer.java
示例4: detach
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Detaches a previously attached {@link StatsStorage} from this UI server.
 *
 * <p>Every listener this server registered on the storage (matched by identity, not equality)
 * is deregistered and dropped from {@code listeners}; the storage is removed from
 * {@code statsStorageInstances}, so {@code isAttached} reports false afterwards and the storage
 * can be attached again; every UI module is then notified through {@code onDetach}. Detaching
 * an instance that is not in {@code statsStorageInstances} does nothing.
 *
 * @param statsStorage storage to detach; must not be null
 * @throws IllegalArgumentException if {@code statsStorage} is null
 */
@Override
public synchronized void detach(StatsStorage statsStorage) {
    if (statsStorage == null)
        throw new IllegalArgumentException("StatsStorage cannot be null");
    if (!statsStorageInstances.contains(statsStorage))
        return; //No op

    // Collect first, then remove with removeIf: this works whether `listeners` is a plain list or a
    // copy-on-write list (whose iterators do not support remove()).
    List<StatsStorageListener> toDeregister = new ArrayList<>();
    for (Pair<StatsStorage, StatsStorageListener> p : listeners) {
        if (p.getFirst() == statsStorage) { //Same object, not equality
            toDeregister.add(p.getSecond());
        }
    }
    listeners.removeIf(p -> p.getFirst() == statsStorage);
    for (StatsStorageListener listener : toDeregister) {
        statsStorage.deregisterStatsStorageListener(listener);
    }
    boolean found = !toDeregister.isEmpty();

    // Previously missing: without this the storage stayed in statsStorageInstances, so isAttached()
    // kept returning true and a later attach() of the same instance was silently ignored.
    statsStorageInstances.remove(statsStorage);

    for (UIModule uiModule : uiModules) {
        uiModule.onDetach(statsStorage);
    }
    if (found) {
        log.info("StatsStorage instance detached from UI: {}", statsStorage);
    }
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:23,代码来源:PlayUIServer.java
示例5: testListenersViaModel
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Checks that listeners set on a {@link MultiLayerNetwork} are exercised by
 * {@code testListenersForModel}, and that the StatsListener's in-memory storage ends up with
 * exactly one session containing two worker IDs.
 */
@Test
public void testListenersViaModel() {
    TestListener.clearCounts();

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .list()
            .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .nIn(10).nOut(10)
                    .activation(Activation.TANH)
                    .build())
            .build();
    MultiLayerNetwork network = new MultiLayerNetwork(conf);
    network.init();

    StatsStorage storage = new InMemoryStatsStorage();
    network.setListeners(new TestListener(), new StatsListener(storage));
    testListenersForModel(network, null);

    assertEquals(1, storage.listSessionIDs().size());
    String onlySession = storage.listSessionIDs().get(0);
    assertEquals(2, storage.listWorkerIDsForSession(onlySession).size());
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:21,代码来源:TestListeners.java
示例6: testListenersViaModelGraph
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Same check as the MultiLayerNetwork variant, but for a single-layer
 * {@link ComputationGraph}: listeners are exercised and the StatsListener's storage ends up
 * with one session containing two worker IDs.
 */
@Test
public void testListenersViaModelGraph() {
    TestListener.clearCounts();

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("0",
                    new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
                            .nIn(10).nOut(10)
                            .activation(Activation.TANH)
                            .build(),
                    "in")
            .setOutputs("0")
            .build();
    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();

    StatsStorage storage = new InMemoryStatsStorage();
    graph.setListeners(new TestListener(), new StatsListener(storage));
    testListenersForModel(graph, null);

    assertEquals(1, storage.listSessionIDs().size());
    String onlySession = storage.listSessionIDs().get(0);
    assertEquals(2, storage.listWorkerIDsForSession(onlySession).size());
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:23,代码来源:TestListeners.java
示例7: onAttach
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Records every session in {@code statsStorage} that lists {@code StatsListener.TYPE_ID} among
 * its type IDs (mapping session ID to this storage), then picks a default session if none is
 * selected yet.
 */
@Override
public synchronized void onAttach(StatsStorage statsStorage) {
    for (String sessionID : statsStorage.listSessionIDs()) {
        boolean hasTrainingStats = false;
        for (String typeID : statsStorage.listTypeIDsForSession(sessionID)) {
            if (StatsListener.TYPE_ID.equals(typeID)) {
                hasTrainingStats = true;
                break;
            }
        }
        if (hasTrainingStats) {
            knownSessionIDs.put(sessionID, statsStorage);
        }
    }

    if (currentSessionID == null) {
        getDefaultSession();
    }
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:14,代码来源:TrainModule.java
示例8: onDetach
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Forgets every session backed by the detached {@code statsStorage} (matched by identity).
 *
 * <p>For each such session the entry in {@code knownSessionIDs} and the per-session worker
 * bookkeeping in {@code workerIdxToName} / {@code workerIdxCount} are removed. If the currently
 * selected session was one of them, the selection is cleared and a new default session is
 * chosen from the remaining ones.
 */
@Override
public synchronized void onDetach(StatsStorage statsStorage) {
    // Collect first, remove afterwards: removing from knownSessionIDs while iterating its keySet()
    // throws ConcurrentModificationException for non-concurrent maps.
    List<String> detachedSessions = new ArrayList<>();
    for (Map.Entry<String, StatsStorage> entry : knownSessionIDs.entrySet()) {
        if (entry.getValue() == statsStorage) {
            detachedSessions.add(entry.getKey());
        }
    }

    for (String sessionID : detachedSessions) {
        knownSessionIDs.remove(sessionID);
        workerIdxToName.remove(sessionID);
        workerIdxCount.remove(sessionID);
    }

    // Otherwise currentSessionID would keep pointing at a session with no storage behind it.
    if (currentSessionID != null && detachedSessions.contains(currentSessionID)) {
        currentSessionID = null;
        getDefaultSession();
    }
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:9,代码来源:TrainModule.java
示例9: getWorkerIdForIndex
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Maps a worker index to a worker ID for the current session.
 *
 * <p>Indices are assigned lazily per session: every worker ID reported for
 * {@code StatsListener.TYPE_ID} that has no index yet receives the next counter value, in
 * sorted ID order. Previously assigned indices never change.
 *
 * @param workerIdx index to resolve
 * @return the worker ID for that index, or null if no session is selected or the index is
 *         unknown (e.g. too high)
 */
private synchronized String getWorkerIdForIndex(int workerIdx) {
    String sid = currentSessionID;
    if (sid == null)
        return null;

    Map<Integer, String> idxToId = workerIdxToName.get(sid);
    if (idxToId == null) {
        idxToId = Collections.synchronizedMap(new HashMap<>());
        workerIdxToName.put(sid, idxToId);
    }
    if (idxToId.containsKey(workerIdx)) {
        return idxToId.get(workerIdx);
    }

    // The session's storage may have been detached: knownSessionIDs entries are removed on detach.
    // Previously this dereferenced null; now only already-assigned indices can be resolved.
    StatsStorage ss = knownSessionIDs.get(sid);
    if (ss == null) {
        return idxToId.get(workerIdx);
    }
    List<String> reportedWorkers = ss.listWorkerIDsForSessionAndType(sid, StatsListener.TYPE_ID);
    if (reportedWorkers == null) {
        return idxToId.get(workerIdx);
    }

    //Need to record new worker... Get counter
    AtomicInteger counter = workerIdxCount.get(sid);
    if (counter == null) {
        counter = new AtomicInteger(0);
        workerIdxCount.put(sid, counter);
    }

    // Sort so that newly seen workers receive indices in a deterministic order
    List<String> allWorkerIds = new ArrayList<>(reportedWorkers);
    Collections.sort(allWorkerIds);
    //Ensure all workers have been assigned an index
    for (String s : allWorkerIds) {
        if (idxToId.containsValue(s))
            continue;
        //Unknown worker ID:
        idxToId.put(counter.getAndIncrement(), s);
    }
    //May still return null if index is wrong/too high...
    return idxToId.get(workerIdx);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:40,代码来源:TrainModule.java
示例10: enableRemoteListener
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Enables the remote receiver module and points it at {@code statsStorage}. When
 * {@code attach} is true and the router is also a {@link StatsStorage}, it is additionally
 * attached to this UI server.
 *
 * @param statsStorage router that remotely received stats are passed to
 * @param attach       whether to also attach the router to the UI (only if it is a StatsStorage)
 */
@Override
public void enableRemoteListener(StatsStorageRouter statsStorage, boolean attach) {
    remoteReceiverModule.setEnabled(true);
    remoteReceiverModule.setStatsStorage(statsStorage);

    boolean shouldAttach = attach && statsStorage instanceof StatsStorage;
    if (!shouldAttach) {
        return;
    }
    attach((StatsStorage) statsStorage);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:9,代码来源:PlayUIServer.java
示例11: testUIMultipleSessions
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Manual UI check (ignored in automated runs): three separate networks are trained on Iris,
 * each reporting to its own {@link InMemoryStatsStorage} attached to the shared UI server.
 * The final long sleep keeps the JVM alive so the UI can be inspected by hand.
 */
@Test
@Ignore
public void testUIMultipleSessions() throws Exception {
    final int sessionCount = 3;
    final int fitsPerSession = 20;

    for (int sessionIdx = 0; sessionIdx < sessionCount; sessionIdx++) {
        StatsStorage storage = new InMemoryStatsStorage();
        UIServer server = UIServer.getInstance();
        server.attach(storage);

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .list()
                .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build())
                .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(4).nOut(3).build())
                .pretrain(false)
                .backprop(true)
                .build();
        MultiLayerNetwork network = new MultiLayerNetwork(conf);
        network.init();
        network.setListeners(new StatsListener(storage), new ScoreIterationListener(1));

        DataSetIterator irisData = new IrisDataSetIterator(150, 150);
        for (int fit = 0; fit < fitsPerSession; fit++) {
            network.fit(irisData);
            Thread.sleep(100);
        }
    }

    Thread.sleep(1_000_000L);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:34,代码来源:TestPlayUI.java
示例12: testUICompGraph
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Manual UI check (ignored in automated runs): a two-layer {@link ComputationGraph} is trained
 * on Iris while reporting to an {@link InMemoryStatsStorage} attached to the UI server. The
 * final sleep keeps the JVM alive so the UI can be inspected by hand.
 */
@Test
@Ignore
public void testUICompGraph() throws Exception {
    StatsStorage storage = new InMemoryStatsStorage();
    UIServer server = UIServer.getInstance();
    server.attach(storage);

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("in")
            .addLayer("L0", new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build(), "in")
            .addLayer("L1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(4).nOut(3).build(), "L0")
            .pretrain(false)
            .backprop(true)
            .setOutputs("L1")
            .build();
    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();
    graph.setListeners(new StatsListener(storage), new ScoreIterationListener(1));

    final int fitCount = 100;
    DataSetIterator irisData = new IrisDataSetIterator(150, 150);
    for (int fit = 0; fit < fitCount; fit++) {
        graph.fit(irisData);
        Thread.sleep(100);
    }

    Thread.sleep(100_000L);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:31,代码来源:TestPlayUI.java
示例13: ConvolutionalIterationListener
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
public ConvolutionalIterationListener(StatsStorageRouter ssr, int iterations, boolean openBrowser, String sessionID,
String workerID) {
this.ssr = ssr;
if (sessionID == null) {
//TODO handle syncing session IDs across different listeners in the same model...
this.sessionID = UUID.randomUUID().toString();
} else {
this.sessionID = sessionID;
}
if (workerID == null) {
this.workerID = UIDProvider.getJVMUID() + "_" + Thread.currentThread().getId();
} else {
this.workerID = workerID;
}
String subPath = "activations";
this.freq = iterations;
this.openBrowser = openBrowser;
path = "http://localhost:" + UIServer.getInstance().getPort() + "/" + subPath;
if (openBrowser && ssr instanceof StatsStorage) {
UIServer.getInstance().attach((StatsStorage) ssr);
}
System.out.println("ConvolutionIterationListener path: " + path);
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:28,代码来源:ConvolutionalIterationListener.java
示例14: testParallelStatsListenerCompatibility
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Manual check (ignored in automated runs) that a StatsListener reporting to the UI keeps
 * working when the network is trained through {@link EarlyStoppingParallelTrainer}. Training
 * is capped at 500 epochs and the run is expected to end because of that epoch condition.
 */
@Test
@Ignore //To be run manually
public void testParallelStatsListenerCompatibility() throws Exception {
    UIServer server = UIServer.getInstance();

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(new Sgd())
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
            .layer(1, new OutputLayer.Builder().nIn(3).nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false)
            .backprop(true)
            .build();
    MultiLayerNetwork network = new MultiLayerNetwork(conf);

    // The UI must be able to report results from parallel training; StatsListener may fail if
    // certain properties are not set on the model.
    StatsStorage storage = new InMemoryStatsStorage();
    network.setListeners(new StatsListener(storage));
    server.attach(storage);

    DataSetIterator irisData = new IrisDataSetIterator(50, 500);
    EarlyStoppingModelSaver<MultiLayerNetwork> modelSaver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<MultiLayerNetwork> stoppingConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(500))
            .scoreCalculator(new DataSetLossCalculator(irisData, true))
            .evaluateEveryNEpochs(2)
            .modelSaver(modelSaver)
            .build();

    IEarlyStoppingTrainer<MultiLayerNetwork> parallelTrainer =
            new EarlyStoppingParallelTrainer<>(stoppingConf, network, irisData, null, 3, 6, 2);
    EarlyStoppingResult<MultiLayerNetwork> outcome = parallelTrainer.fit();
    System.out.println(outcome);

    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, outcome.getTerminationReason());
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:37,代码来源:TestParallelEarlyStoppingUI.java
示例15: main
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Starts the UI server, attaches a fresh in-memory stats storage to it, and enables the
 * remote-listener endpoint (presumably so training processes elsewhere can post stats here).
 *
 * @param args unused
 */
public static void main(String[] args) {
    UIServer uiServer = UIServer.getInstance();
    StatsStorage inMemoryStorage = new InMemoryStatsStorage();
    uiServer.attach(inMemoryStorage);
    uiServer.enableRemoteListener();
}
开发者ID:buybrain,项目名称:docker-dl4j-ui,代码行数:7,代码来源:Server.java
示例16: LSTMTrainer
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Sets the training hyper-parameters, builds the ABC training-set iterator, constructs and
 * initialises a three-layer GravesLSTM network, and attaches a StatsListener reporting to an
 * in-memory storage shown by the UI server.
 *
 * @param trainingSet text file containing several ABC music files
 * @param seed        seed for the {@link Random} handed to the iterator and for the network
 *                    configuration
 * @throws IOException presumably when the training-set iterator cannot read {@code trainingSet}
 */
public LSTMTrainer(String trainingSet, int seed) throws IOException {
    // Hyper-parameters ("original" values noted for reference)
    lstmLayerSize_ = 200;                       // original 200
    batchSize_ = 32;                            // original 32
    truncatedBackPropThroughTimeLength_ = 50;
    nbEpochs_ = 100;
    learningRate_ = 0.04;                       // 0.1 original // best 0.05 3epochs
    generateSamplesEveryNMinibatches_ = 200;
    generationInitialization_ = "X";
    seed_ = seed;
    random_ = new Random(seed);
    output_ = null;

    trainingSetIterator_ = new ABCIterator(trainingSet, Charset.forName("ASCII"), batchSize_, random_);
    charToInt_ = trainingSetIterator_.getCharToInt();
    intToChar_ = trainingSetIterator_.getIntToChar();
    exampleLength_ = trainingSetIterator_.getExampleLength();

    int outputSize = trainingSetIterator_.totalOutcomes();
    int inputSize = trainingSetIterator_.inputColumns();
    lstmNet_ = new MultiLayerNetwork(buildLstmConfiguration(inputSize, outputSize));
    lstmNet_.init();

    // Training progress is reported through the UI (ScoreIterationListener / HistogramIterationListener
    // were used here previously).
    UIServer uiServer = UIServer.getInstance();
    StatsStorage statsStorage = new InMemoryStatsStorage();
    uiServer.attach(statsStorage);
    lstmNet_.setListeners(new StatsListener(statsStorage));

    if (ExecutionParameters.verbose) {
        printLayerParameterCounts();
    }
}

/** Builds the three GravesLSTM layers plus softmax RNN output layer, trained with truncated BPTT. */
private MultiLayerConfiguration buildLstmConfiguration(int nIn, int nOut) {
    return new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .learningRate(learningRate_)
            .rmsDecay(0.95) // 0.95 original
            .seed(seed_)
            .regularization(true) // true original
            .l2(0.001)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.RMSPROP)
            .list()
            .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(lstmLayerSize_)
                    .activation("tanh").build())
            .layer(1, new GravesLSTM.Builder().nIn(lstmLayerSize_).nOut(lstmLayerSize_)
                    .activation("tanh").build())
            .layer(2, new GravesLSTM.Builder().nIn(lstmLayerSize_).nOut(lstmLayerSize_)
                    .activation("tanh").build())
            .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation("softmax")
                    .nIn(lstmLayerSize_).nOut(nOut).build())
            .backpropType(BackpropType.TruncatedBPTT)
            .tBPTTForwardLength(truncatedBackPropThroughTimeLength_)
            .tBPTTBackwardLength(truncatedBackPropThroughTimeLength_)
            .pretrain(false).backprop(true)
            .build();
}

/** Prints the parameter count of every layer of {@code lstmNet_} followed by the total. */
private void printLayerParameterCounts() {
    Layer[] netLayers = lstmNet_.getLayers();
    int total = 0;
    for (int layerIdx = 0; layerIdx < netLayers.length; layerIdx++) {
        int layerParams = netLayers[layerIdx].numParams();
        System.out.println("Number of parameters in layer " + layerIdx + ": " + layerParams);
        total += layerParams;
    }
    System.out.println("Total number of network parameters: " + total);
}
开发者ID:paveyry,项目名称:LyreLand,代码行数:69,代码来源:LSTMTrainer.java
示例17: train
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Trains an Iris classifier from the command-line options and saves it.
 *
 * <p>Options read: {@code e} (number of epochs), {@code o} (output model path) and {@code i}
 * (input CSV, passed to {@code DataIterator.irisCsv}). Training progress is reported to the
 * UI server through an in-memory StatsListener. The model is written with
 * {@code ModelSerializer.writeModel(model, modelName, true)} and the normalizer to four files
 * {@code <modelName>.norm1..norm4}. A save failure is logged as an error rather than thrown.
 *
 * @param c parsed command line
 */
private static void train(CommandLine c) {
    int nEpochs = Integer.parseInt(c.getOptionValue("e"));
    String modelName = c.getOptionValue("o");

    DataIterator<NormalizerStandardize> it = DataIterator.irisCsv(c.getOptionValue("i"));
    RecordReaderDataSetIterator trainData = it.getIterator();
    NormalizerStandardize normalizer = it.getNormalizer();
    log.info("Data Loaded");

    MultiLayerConfiguration conf = net(4, 3);
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    UIServer uiServer = UIServer.getInstance();
    StatsStorage statsStorage = new InMemoryStatsStorage();
    uiServer.attach(statsStorage);
    model.setListeners(Arrays.asList(new ScoreIterationListener(1), new StatsListener(statsStorage)));

    for (int i = 0; i < nEpochs; i++) {
        log.info("Starting epoch {} of {}", i, nEpochs);
        while (trainData.hasNext()) {
            model.fit(trainData.next());
        }
        log.info("Finished epoch {}", i);
        trainData.reset();
    }

    try {
        ModelSerializer.writeModel(model, modelName, true);
        normalizer.save(
                new File(modelName + ".norm1"),
                new File(modelName + ".norm2"),
                new File(modelName + ".norm3"),
                new File(modelName + ".norm4")
        );
        // Only report success once both writes have completed; previously this was logged even
        // after a failed save.
        log.info("Model saved to: {}", modelName);
    } catch (IOException e) {
        // Best-effort as before (no rethrow), but surfaced through the logger with the cause.
        log.error("Failed to save model to: {}", modelName, e);
    }
}
开发者ID:wmeddie,项目名称:dl4j-trainer-archetype,代码行数:47,代码来源:Train.java
示例18: getSystemData
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Returns memory, hardware and software information for the current session as JSON.
 *
 * <p>The response map holds {@code updateTimestamp} (the value in {@code lastUpdateForSession},
 * or -1 when absent), {@code memory} (built from the static infos and the updates recorded in
 * the 5 minutes before the most recent update), {@code hardware} and {@code software}. When no
 * session is selected, or its storage is no longer in {@code knownSessionIDs} (entries are
 * removed when a storage is detached), empty static-info / update lists are used instead.
 */
public Result getSystemData() {
    // Read the field once so every lookup below refers to the same session.
    String sid = currentSessionID;
    StatsStorage ss = (sid == null ? null : knownSessionIDs.get(sid));
    boolean noData = ss == null;

    Long lastUpdate = (sid == null ? null : lastUpdateForSession.get(sid));
    if (lastUpdate == null)
        lastUpdate = -1L;

    I18N i18n = I18NProvider.getInstance();

    //First: get the MOST RECENT update...
    //Then get all updates from most recent - 5 minutes -> TODO make this configurable...
    List<Persistable> allStatic = Collections.emptyList();
    List<Persistable> latestUpdates = Collections.emptyList();
    if (!noData) {
        List<Persistable> staticInfos = ss.getAllStaticInfos(sid, StatsListener.TYPE_ID);
        if (staticInfos != null) {
            allStatic = staticInfos; // a null result is treated like the no-data case (empty list)
        }
        latestUpdates = ss.getLatestUpdateAllWorkers(sid, StatsListener.TYPE_ID);
    }

    long lastUpdateTime = -1;
    if (latestUpdates == null || latestUpdates.isEmpty()) {
        noData = true;
    } else {
        for (Persistable p : latestUpdates) {
            lastUpdateTime = Math.max(lastUpdateTime, p.getTimeStamp());
        }
    }
    long fromTime = lastUpdateTime - 5 * 60 * 1000; //TODO Make configurable
    List<Persistable> lastNMinutes =
            (noData ? null : ss.getAllUpdatesAfter(sid, StatsListener.TYPE_ID, fromTime));

    Map<String, Object> mem = getMemory(allStatic, lastNMinutes, i18n);
    Pair<Map<String, Object>, Map<String, Object>> hwSwInfo = getHardwareSoftwareInfo(allStatic, i18n);

    Map<String, Object> ret = new HashMap<>();
    ret.put("updateTimestamp", lastUpdate);
    ret.put("memory", mem);
    ret.put("hardware", hwSwInfo.getFirst());
    ret.put("software", hwSwInfo.getSecond());
    return ok(Json.toJson(ret));
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:45,代码来源:TrainModule.java
示例19: isAttached
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Reports whether {@code statsStorage} is present in this server's {@code statsStorageInstances}.
 *
 * @param statsStorage storage to look up
 * @return true if the list contains it
 */
@Override
public boolean isAttached(StatsStorage statsStorage) {
    boolean present = statsStorageInstances.contains(statsStorage);
    return present;
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:5,代码来源:PlayUIServer.java
示例20: getStatsStorageInstances
import org.deeplearning4j.api.storage.StatsStorage; //导入依赖的package包/类
/**
 * Returns the storages currently held in {@code statsStorageInstances}.
 *
 * @return a new list containing them; modifying it does not affect this server
 */
@Override
public List<StatsStorage> getStatsStorageInstances() {
    List<StatsStorage> snapshot = new ArrayList<>(statsStorageInstances);
    return snapshot;
}
开发者ID:deeplearning4j,项目名称:deeplearning4j,代码行数:5,代码来源:PlayUIServer.java
注：本文中的org.deeplearning4j.api.storage.StatsStorage类示例整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论