Java SavedModelBundle Class Code Examples


This article collects typical usage examples of the org.tensorflow.SavedModelBundle class in Java. If you are wondering how the SavedModelBundle class is used in practice, or are looking for working examples of it, the curated code samples below may help.



The SavedModelBundle class belongs to the org.tensorflow package. Below are 20 code examples of the class, sorted by popularity by default.

Example 1: loadModel

import org.tensorflow.SavedModelBundle; // import the required package/class
@Override
public SavedModelBundle loadModel(final Location source,
	final String modelName, final String... tags) throws IOException
{
	final String key = modelName + "/" + Arrays.toString(tags);

	// If the model is already cached in memory, return it.
	if (models.containsKey(key)) return models.get(key);

	// Get a local directory with unpacked model data.
	final File modelDir = modelDir(source, modelName);

	// Load the saved model.
	final SavedModelBundle model = //
		SavedModelBundle.load(modelDir.getAbsolutePath(), tags);

	// Cache the model so that later calls with the same key reuse it.
	models.put(key, model);

	return model;
}
 
Developer: imagej, Project: imagej-tensorflow, Lines of code: 19, Source: DefaultTensorFlowService.java
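
The cache lookup in Example 1 can be sketched in isolation. This is a minimal illustration, not the imagej-tensorflow implementation: ModelCache and its loader callback are hypothetical names, and the type parameter M stands in for SavedModelBundle so the snippet runs without TensorFlow. It builds the same modelName + "/" + Arrays.toString(tags) key and relies on ConcurrentHashMap.computeIfAbsent so that each key is loaded at most once.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;

/** Hypothetical cache; M stands in for SavedModelBundle. */
final class ModelCache<M> {
    private final Map<String, M> models = new ConcurrentHashMap<>();
    // Hypothetical loader callback: (modelName, tags) -> loaded model.
    private final BiFunction<String, String[], M> loader;

    ModelCache(BiFunction<String, String[], M> loader) {
        this.loader = loader;
    }

    /** Same key scheme as Example 1: model name plus the printed tag array. */
    static String key(String modelName, String... tags) {
        return modelName + "/" + Arrays.toString(tags);
    }

    /** Returns the cached model, loading and storing it on first use. */
    M get(String modelName, String... tags) {
        return models.computeIfAbsent(key(modelName, tags),
            k -> loader.apply(modelName, tags));
    }

    int size() {
        return models.size();
    }
}
```

Because the tags are part of the key, the same model name loaded with different tags produces separate cache entries.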


Example 2: dispose

import org.tensorflow.SavedModelBundle; // import the required package/class
@Override
public void dispose() {
	// Dispose models.
	for (final SavedModelBundle model : models.values()) {
		model.close();
	}
	models.clear();

	// Dispose graphs.
	for (final Graph graph : graphs.values()) {
		graph.close();
	}
	graphs.clear();

	// Dispose labels.
	labelses.clear();
}
 
Developer: imagej, Project: imagej-tensorflow, Lines of code: 18, Source: DefaultTensorFlowService.java
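
The close-then-clear pattern in Example 2 can be tried out without TensorFlow. In the sketch below, TrackedResource and ResourceRegistry are hypothetical stand-ins (for a SavedModelBundle or Graph and for the service that owns them); only the ordering of the two steps is taken from the example.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical stand-in for a native-backed resource such as a model bundle. */
final class TrackedResource implements AutoCloseable {
    boolean closed;

    @Override
    public void close() {
        closed = true;
    }
}

/** Hypothetical registry that owns its resources and releases them on dispose(). */
final class ResourceRegistry {
    private final Map<String, TrackedResource> resources = new LinkedHashMap<>();

    /** Returns the resource registered under the key, creating it on first use. */
    TrackedResource register(String key) {
        return resources.computeIfAbsent(key, k -> new TrackedResource());
    }

    /** Close every resource first, then clear, so no closed handle stays reachable. */
    void dispose() {
        for (TrackedResource resource : resources.values()) {
            resource.close();
        }
        resources.clear();
    }

    int size() {
        return resources.size();
    }
}
```

Clearing the map after closing also makes a second dispose() call a harmless no-op.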


Example 3: importGraph

import org.tensorflow.SavedModelBundle; // import the required package/class
private TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle model) {
    TensorFlowModel result = new TensorFlowModel();
    for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
        TensorFlowModel.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName"

        importInputs(signatureEntry.getValue().getInputsMap(), signature);
        for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
            String outputName = output.getKey();
            try {
                NodeDef node = getNode(nameOf(output.getValue().getName()), graph.getGraphDef());
                importNode(node, graph.getGraphDef(), model, result);
                signature.output(outputName, nameOf(output.getValue().getName()));
            }
            catch (IllegalArgumentException e) {
                signature.skippedOutput(outputName, Exceptions.toMessageString(e));
            }
        }
    }
    return result;
}
 
Developer: vespa-engine, Project: vespa, Lines of code: 21, Source: TensorFlowImporter.java
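
Example 3 calls a nameOf helper whose body is not shown. TensorFlow tensor names have the form operationName:outputIndex (for example add:0), while node names carry no index, so a plausible reading is that nameOf drops the index suffix. The sketch below rests on that assumption; TensorNames is a hypothetical class and this is not the Vespa source.

```java
/** Hypothetical helper; assumes nameOf only strips the output index. */
final class TensorNames {
    private TensorNames() {}

    /** Drops a trailing ":<index>" so the result can be matched against node names. */
    static String nameOf(String tensorName) {
        int colon = tensorName.lastIndexOf(':');
        return colon < 0 ? tensorName : tensorName.substring(0, colon);
    }
}
```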


Example 4: tensorFunctionOf

import org.tensorflow.SavedModelBundle; // import the required package/class
private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
    // Import arguments lazily below, as some nodes have unused arguments leading to unsupported ops
    // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
    switch (tfNode.getOp().toLowerCase()) {
        case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
        case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
        case "elu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.elu());
        case "identity" : return operationMapper.identity(tfNode, model, result);
        case "placeholder" : return operationMapper.placeholder(tfNode, result);
        case "relu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.relu());
        case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result));
        case "sigmoid": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.sigmoid());
        case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result));
        default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
    }
}
 
Developer: vespa-engine, Project: vespa, Lines of code: 17, Source: TensorFlowImporter.java
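
The dispatch in Example 4 maps a lower-cased operation name to a function and rejects everything else. Here is a reduced sketch of that idea, limited to element-wise operations. ScalarOps is a hypothetical class, and the formulas are the standard definitions (ELU with alpha = 1), not necessarily what Vespa's ScalarFunctions computes. It lower-cases with Locale.ROOT because String.toLowerCase() without a locale depends on the default locale, which can change how op names match.

```java
import java.util.Locale;
import java.util.function.DoubleUnaryOperator;

/** Hypothetical mapping from TensorFlow op names to scalar functions. */
final class ScalarOps {
    private ScalarOps() {}

    static DoubleUnaryOperator forOp(String op) {
        switch (op.toLowerCase(Locale.ROOT)) {
            case "acos":
                return Math::acos;
            case "elu":
                // Standard ELU with alpha = 1 (an assumption of this sketch).
                return x -> x < 0 ? Math.exp(x) - 1 : x;
            case "relu":
                return x -> Math.max(0.0, x);
            case "sigmoid":
                return x -> 1 / (1 + Math.exp(-x));
            default:
                // Fail loudly on anything unmapped, as Example 4 does.
                throw new IllegalArgumentException(
                    "Conversion of TensorFlow operation '" + op + "' is not supported");
        }
    }
}
```

Failing in the default branch lets the caller skip an unsupported output, which is what Example 3 does when it catches IllegalArgumentException.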


Example 5: identity

import org.tensorflow.SavedModelBundle; // import the required package/class
TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) {
    if ( ! tfNode.getName().endsWith("/read"))
        throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identity " +
                                           "nodes are only supported when reading variables");
    if (tfNode.getInputList().size() != 1)
        throw new IllegalArgumentException("A Variable/read node must have one input but has " +
                                           tfNode.getInputList().size());

    String name = tfNode.getInput(0);
    AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
    if (shapes == null)
        throw new IllegalArgumentException("Referenced variable '" + name + "' is missing a tensor output shape");
    Session.Runner fetched = model.session().runner().fetch(name);
    List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
    if ( importedTensors.size() != 1)
        throw new IllegalStateException("Expected 1 tensor from reading Variable " + name + ", but got " +
                                        importedTensors.size());
    Tensor constant = tensorConverter.toVespaTensor(importedTensors.get(0));
    result.constant(name, constant);
    return new TypedTensorFunction(constant.type(),
                                   new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(" + name + ")")));
}
 
Developer: vespa-engine, Project: vespa, Lines of code: 23, Source: OperationMapper.java


Example 6: testLoadModel

import org.tensorflow.SavedModelBundle; // import the required package/class
public void testLoadModel() throws Exception {
    String modelDir = "examples/tensorflow/estimator/model";
    // Closing the bundle also releases the session and graph it owns.
    try (SavedModelBundle bundle = SavedModelBundle.load(modelDir + "/" + SpongeUtils.getLastSubdirectory(modelDir), "serve")) {
        Session s = bundle.session();
        // Tensors hold native memory, so close them as well.
        try (Tensor x = Tensor.create(new float[] { 2, 5, 8, 1 });
                Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
            logger.info("y = {}", y.floatValue());
        }
    }
}
 
Developer: softelnet, Project: sponge, Lines of code: 12, Source: TensorflowTest.java


Example 7: importModel

import org.tensorflow.SavedModelBundle; // import the required package/class
/**
 * Imports a saved TensorFlow model from a directory.
 * The model should be saved as a .pbtxt or .pb file.
 * The name of the model is taken as the pb/pbtxt file name (not including the file ending).
 *
 * @param modelDir the directory containing the TensorFlow model files to import
 */
public TensorFlowModel importModel(String modelDir) {
    try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
        return importModel(model);
    }
    catch (IllegalArgumentException e) {
        throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
    }
}
 
Developer: vespa-engine, Project: vespa, Lines of code: 16, Source: TensorFlowImporter.java


Example 8: importNode

import org.tensorflow.SavedModelBundle; // import the required package/class
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
    TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result);
    // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
    // will be used
    result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), new TensorFunctionNode(function.function())));
    return function;
}
 
Developer: vespa-engine, Project: vespa, Lines of code: 9, Source: TensorFlowImporter.java


Example 9: assertEqualResult

import org.tensorflow.SavedModelBundle; // import the required package/class
private void assertEqualResult(SavedModelBundle model, TensorFlowModel result, String inputName, String operationName) {
    Tensor tfResult = tensorFlowExecute(model, inputName, operationName);
    Context context = contextFrom(result);
    Tensor placeholder = placeholderArgument();
    context.put(inputName, new TensorValue(placeholder));
    Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
    assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
}
 
Developer: vespa-engine, Project: vespa, Lines of code: 9, Source: TensorflowImportTestCase.java


Example 10: tensorFlowExecute

import org.tensorflow.SavedModelBundle; // import the required package/class
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
    Session.Runner runner = model.session().runner();
    org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ 1, 784 }, FloatBuffer.allocate(784));
    runner.feed(inputName, placeholder);
    List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
    assertEquals(1, results.size());
    return new TensorConverter().toVespaTensor(results.get(0));
}
 
Developer: vespa-engine, Project: vespa, Lines of code: 9, Source: TensorflowImportTestCase.java
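
In Example 10 the literal 784 appears both in the shape { 1, 784 } and in FloatBuffer.allocate(784): the buffer must hold exactly as many elements as the shape implies. The sketch below derives the buffer size from the shape so the two cannot drift apart. Shapes, numElements and zeros are hypothetical helper names, not part of the TensorFlow API.

```java
import java.nio.FloatBuffer;

/** Hypothetical helpers for sizing a buffer from a tensor shape. */
final class Shapes {
    private Shapes() {}

    /** Product of all dimensions; an empty shape (a scalar) has one element. */
    static int numElements(long[] shape) {
        long count = 1;
        for (long dim : shape) {
            if (dim < 0) {
                throw new IllegalArgumentException("Unknown dimension: " + dim);
            }
            count = Math.multiplyExact(count, dim);
        }
        return Math.toIntExact(count);
    }

    /** Zero-filled buffer with exactly as many floats as the shape requires. */
    static FloatBuffer zeros(long... shape) {
        return FloatBuffer.allocate(numElements(shape));
    }
}
```

With such a helper, the placeholder in Example 10 could be built from a single shape array passed to both the tensor factory and zeros(...).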


Example 11: run

import org.tensorflow.SavedModelBundle; // import the required package/class
@Override
public void run() {
	try {
		validateFormat(originalImage);
		RandomAccessibleInterval<FloatType> normalizedImage = normalize(originalImage);

		final long loadModelStart = System.nanoTime();
		final HTTPLocation source = new HTTPLocation(MODEL_URL);
		final SavedModelBundle model = //
			tensorFlowService.loadModel(source, MODEL_NAME, MODEL_TAG);
		final long loadModelEnd = System.nanoTime();
		log.info(String.format(
			"Loaded microscope focus image quality model in %dms", (loadModelEnd -
				loadModelStart) / 1000000));

		// Extract names from the model signature.
		// The strings "input", "probabilities" and "patches" are meant to be
		// in sync with the model exporter (export_saved_model()) in Python.
		final SignatureDef sig = MetaGraphDef.parseFrom(model.metaGraphDef())
			.getSignatureDefOrThrow(DEFAULT_SERVING_SIGNATURE_DEF_KEY);
		try (final Tensor inputTensor = Tensors.tensor(normalizedImage)) {
			// Run the model.
			final long runModelStart = System.nanoTime();
			final List<Tensor> fetches = model.session().runner() //
				.feed(opName(sig.getInputsOrThrow("input")), inputTensor) //
				.fetch(opName(sig.getOutputsOrThrow("probabilities"))) //
				.fetch(opName(sig.getOutputsOrThrow("patches"))) //
				.run();
			final long runModelEnd = System.nanoTime();
			log.info(String.format("Ran image through model in %dms", //
				(runModelEnd - runModelStart) / 1000000));

			// Process the results.
			try (final Tensor probabilities = fetches.get(0);
					final Tensor patches = fetches.get(1))
			{
				processPatches(probabilities, patches);
			}
		}
	}
	catch (final Exception exc) {
		// Use the LogService to report the error.
		log.error(exc);
	}
}
 
Developer: fiji, Project: microscope-image-quality, Lines of code: 46, Source: MicroscopeImageFocusQualityClassifier.java
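
Example 11 releases both fetched tensors with a single try-with-resources statement. Java closes such resources in reverse declaration order, and it does so before any catch block runs, so the tensors are freed even when processing fails. The sketch below makes that ordering visible. CloseOrderDemo and Handle are hypothetical names, with Handle standing in for a Tensor.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical demo of try-with-resources ordering. */
final class CloseOrderDemo {
    private CloseOrderDemo() {}

    /** Stand-in for a Tensor: records its name when closed. */
    static final class Handle implements AutoCloseable {
        private final String name;
        private final List<String> log;

        Handle(String name, List<String> log) {
            this.name = name;
            this.log = log;
        }

        @Override
        public void close() {
            log.add(name);
        }
    }

    /** Returns the sequence of events, optionally failing inside the try body. */
    static List<String> run(boolean fail) {
        List<String> log = new ArrayList<>();
        try (Handle probabilities = new Handle("probabilities", log);
                Handle patches = new Handle("patches", log)) {
            if (fail) {
                throw new IllegalStateException("processing failed");
            }
            log.add("processed");
        } catch (IllegalStateException e) {
            // Both handles are already closed when this block runs.
            log.add("error");
        }
        return log;
    }
}
```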


Example 12: close

import org.tensorflow.SavedModelBundle; // import the required package/class
@Override
public void close(){
	SavedModelBundle bundle = getBundle();

	bundle.close();
}
 
Developer: jpmml, Project: jpmml-tensorflow, Lines of code: 7, Source: SavedModel.java


Example 13: getSession

import org.tensorflow.SavedModelBundle; // import the required package/class
public Session getSession(){
	SavedModelBundle bundle = getBundle();

	return bundle.session();
}
 
Developer: jpmml, Project: jpmml-tensorflow, Lines of code: 6, Source: SavedModel.java


Example 14: getGraph

import org.tensorflow.SavedModelBundle; // import the required package/class
public Graph getGraph(){
	SavedModelBundle bundle = getBundle();

	return bundle.graph();
}
 
Developer: jpmml, Project: jpmml-tensorflow, Lines of code: 6, Source: SavedModel.java


Example 15: getBundle

import org.tensorflow.SavedModelBundle; // import the required package/class
public SavedModelBundle getBundle(){
	return this.bundle;
}
 
Developer: jpmml, Project: jpmml-tensorflow, Lines of code: 4, Source: SavedModel.java


Example 16: setBundle

import org.tensorflow.SavedModelBundle; // import the required package/class
private void setBundle(SavedModelBundle bundle){
	this.bundle = bundle;
}
 
Developer: jpmml, Project: jpmml-tensorflow, Lines of code: 4, Source: SavedModel.java


Example 17: createBatch

import org.tensorflow.SavedModelBundle; // import the required package/class
@Override
protected ArchiveBatch createBatch(String name, String dataset, Predicate<FieldName> predicate){
	ArchiveBatch result = new IntegrationTestBatch(name, dataset, predicate){

		@Override
		public IntegrationTest getIntegrationTest(){
			return EstimatorTest.this;
		}

		@Override
		public PMML getPMML() throws Exception {
			File savedModelDir = getSavedModelDir();

			SavedModelBundle bundle = SavedModelBundle.load(savedModelDir.getAbsolutePath(), "serve");

			try(SavedModel savedModel = new SavedModel(bundle)){
				EstimatorFactory estimatorFactory = EstimatorFactory.newInstance();

				Estimator estimator = estimatorFactory.newEstimator(savedModel);

				PMML pmml = estimator.encodePMML();

				ensureValidity(pmml);

				return pmml;
			}
		}

		private File getSavedModelDir() throws IOException, URISyntaxException {
			ClassLoader classLoader = (EstimatorTest.this.getClass()).getClassLoader();

			String protoPath = ("savedmodel/" + getName() + getDataset() + "/saved_model.pbtxt");

			URL protoResource = classLoader.getResource(protoPath);
			if(protoResource == null){
				throw new NoSuchFileException(protoPath);
			}

			File protoFile = (Paths.get(protoResource.toURI())).toFile();

			return protoFile.getParentFile();
		}
	};

	return result;
}
 
Developer: jpmml, Project: jpmml-tensorflow, Lines of code: 47, Source: EstimatorTest.java
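
getSavedModelDir in Example 17 turns a classpath resource URL into the directory that contains it. The conversion through Paths.get(URI) only works when the resource is a plain file on disk. For a resource packaged inside a JAR the URL uses the jar: scheme and the lookup would fail. The sketch below isolates that step and rejects non-file URLs up front. ResourceDirs and parentDirOf are hypothetical names.

```java
import java.io.File;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;

/** Hypothetical helper for resolving the directory of a file-based resource. */
final class ResourceDirs {
    private ResourceDirs() {}

    /** Returns the directory containing the resource; only file: URLs are accepted. */
    static File parentDirOf(URL resource) {
        if (!"file".equals(resource.getProtocol())) {
            throw new IllegalArgumentException("Not a file: URL: " + resource);
        }
        try {
            return Paths.get(resource.toURI()).toFile().getParentFile();
        } catch (URISyntaxException e) {
            throw new IllegalArgumentException("Malformed resource URL: " + resource, e);
        }
    }
}
```

The returned directory is what SavedModelBundle.load expects as its first argument in Example 17.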


Example 18: importArguments

import org.tensorflow.SavedModelBundle; // import the required package/class
private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
                                                  TensorFlowModel result) {
    return tfNode.getInputList().stream()
                                .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
                                .collect(Collectors.toList());
}
 
Developer: vespa-engine, Project: vespa, Lines of code: 7, Source: TensorFlowImporter.java
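
importNode (Example 8), tensorFunctionOf (Example 4) and importArguments (Example 18) call each other recursively: importing a node first imports each of its inputs. The sketch below shows that recursion on a toy graph. GraphImporter and Node are hypothetical types, and plain strings replace Vespa's tensor functions. As in Example 8, every intermediate node is recorded as its own expression.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/** Hypothetical importer that renders a node graph as nested expressions. */
final class GraphImporter {
    /** Minimal stand-in for a NodeDef: an op and the names of its inputs. */
    static final class Node {
        final String op;
        final List<String> inputs;

        Node(String op, String... inputs) {
            this.op = op;
            this.inputs = Arrays.asList(inputs);
        }
    }

    private final Map<String, Node> graph;
    /** Every imported node, keyed by node name, in import order. */
    final Map<String, String> expressions = new LinkedHashMap<>();

    GraphImporter(Map<String, Node> graph) {
        this.graph = graph;
    }

    /** Imports the named node after recursively importing its inputs. */
    String importNode(String name) {
        Node node = graph.get(name);
        if (node == null) {
            throw new IllegalArgumentException("Unknown node '" + name + "'");
        }
        String expression = node.inputs.isEmpty()
            ? name
            : node.op + "(" + String.join(", ", importArguments(node)) + ")";
        expressions.put(name, expression);
        return expression;
    }

    private List<String> importArguments(Node node) {
        return node.inputs.stream().map(this::importNode).collect(Collectors.toList());
    }
}
```

This sketch has no cycle detection and re-imports a node that is shared by several parents. It therefore assumes a small acyclic graph.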


Example 19: testMnistSoftmaxImport

import org.tensorflow.SavedModelBundle; // import the required package/class
@Test
public void testMnistSoftmaxImport() {
    String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
    SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
    TensorFlowModel result = new TensorFlowImporter().importModel(model);

    // Check constants
    assertEquals(2, result.constants().size());

    Tensor constant0 = result.constants().get("Variable");
    assertNotNull(constant0);
    assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
                 constant0.type());
    assertEquals(7840, constant0.size());

    Tensor constant1 = result.constants().get("Variable_1");
    assertNotNull(constant1);
    assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
                 constant1.type());
    assertEquals(10, constant1.size());

    // Check signatures
    assertEquals(1, result.signatures().size());
    TensorFlowModel.Signature signature = result.signatures().get("serving_default");
    assertNotNull(signature);

    // ... signature inputs
    assertEquals(1, signature.inputs().size());
    TensorType argument0 = signature.inputArgument("x");
    assertNotNull(argument0);
    assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), argument0);

    // ... signature outputs
    assertEquals(1, signature.outputs().size());
    RankingExpression output = signature.outputExpression("y");
    assertNotNull(output);
    assertEquals("add", output.getName());
    assertEquals("" +
                 "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
                 "rename(constant(Variable_1), d0, d1), " +
                 "f(a,b)(a + b))",
                 toNonPrimitiveString(output));

    // Test execution
    assertEqualResult(model, result, "Placeholder", "Variable/read");
    assertEqualResult(model, result, "Placeholder", "Variable_1/read");
    assertEqualResult(model, result, "Placeholder", "MatMul");
    assertEqualResult(model, result, "Placeholder", "add");
}
 
Developer: vespa-engine, Project: vespa, Lines of code: 50, Source: TensorflowImportTestCase.java


Example 20: loadModel

import org.tensorflow.SavedModelBundle; // import the required package/class
/**
 * Extracts a persisted model from the given location.
 * 
 * @param source The location of the model, which must be structured as a ZIP
 *          archive.
 * @param modelName The name of the model by which the source should be
 *          unpacked and cached as needed.
 * @param tags Optional list of tags passed to
 *          {@link SavedModelBundle#load(String, String...)}.
 * @return The extracted TensorFlow {@link SavedModelBundle} object.
 * @throws IOException If something goes wrong reading or unpacking the
 *           archive.
 */
SavedModelBundle loadModel(Location source, String modelName, String... tags)
	throws IOException;
 
Developer: imagej, Project: imagej-tensorflow, Lines of code: 16, Source: TensorFlowService.java



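Note: The org.tensorflow.SavedModelBundle examples in this article were compiled from source-code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects, and copyright remains with the original authors. For distribution and use, refer to each project's license. Do not reproduce without permission.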

