• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python tensorflow.import_graph_def函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.import_graph_def函数的典型用法代码示例。如果您正苦于以下问题：Python import_graph_def函数的具体用法？Python import_graph_def怎么用？Python import_graph_def使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了import_graph_def函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: loadmodle

def loadmodle():
    """Load a frozen graph plus checkpoint data and re-evaluate a saved tensor.

    Parses the GraphDef stored in ``/tmp/load/test.pb``, imports it into the
    session's graph, adds the ``saved1_result:0`` tensor to the variables
    collection, restores values from ``checkpoint.data`` and prints the
    evaluated tensor.

    Raises:
        Exception: whatever ``tf.train.Saver`` raises is printed and then
            re-raised, because the restore step cannot run without a saver.
    """
    # Calling .decode() on a unicode literal fails (implicit ASCII encode on
    # Python 2, no such method on Python 3), so print the literal directly.
    print(u"step2:模型加载测试")
    with tf.Session() as persisted_sess:
        print("---1:load graph")  # load the computation graph
        with gfile.FastGFile("/tmp/load/test.pb", 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # as_default() only has an effect when used as a context manager.
        with persisted_sess.graph.as_default():
            tf.import_graph_def(graph_def, name='')  # import the graph definition

        print("---2,map variables")
        # Look up the tensor and register it among the variables to restore.
        persisted_result = persisted_sess.graph.get_tensor_by_name("saved1_result:0")
        tf.add_to_collection(tf.GraphKeys.VARIABLES, persisted_result)

        # Restore the data.
        print("---3,load data")
        try:
            # 'Saver' misnomer! Better: Persister!
            saver = tf.train.Saver(tf.all_variables())
        except Exception as e:
            print(str(e))
            # Previously execution continued and hit a NameError on `saver`.
            raise
        # Reload the variable data into the tensors.
        saver.restore(persisted_sess, "checkpoint.data")

        # Re-run the computation.
        print(persisted_result.eval())
        print("DONE")
开发者ID:tuling56,项目名称:Python,代码行数:26,代码来源:model_save_restore.py


示例2: run_graph_def

def run_graph_def(graph_def, input_map, outputs):
  """Import `graph_def` into a fresh graph and evaluate `outputs` in it.

  `input_map` is handed to `Session.run` as the feed dict; the import itself
  is done with an empty input map.
  """
  fresh_graph = tf.Graph()
  with fresh_graph.as_default():
    tf.import_graph_def(graph_def, input_map={}, name="")
  with tf.Session(graph=fresh_graph) as session:
    return session.run(outputs, feed_dict=input_map)
开发者ID:DavidNemeskey,项目名称:tensorflow,代码行数:7,代码来源:quantize_graph_test.py


示例3: testInvalidInputForInputMap

 def testInvalidInputForInputMap(self):
   """import_graph_def must raise TypeError when input_map is a list."""
   expected_message = ('input_map must be a dictionary mapping strings to '
                       'Tensor objects.')
   with tf.Graph().as_default():
     with self.assertRaises(TypeError) as raised:
       empty_graph_def = self._MakeGraphDef('')
       tf.import_graph_def(empty_graph_def, input_map=[tf.constant(5.0)])
     self.assertEqual(expected_message, str(raised.exception))
开发者ID:yevgeniyfrenkel,项目名称:tensorflow,代码行数:7,代码来源:importer_test.py


示例4: graphdef_to_pbtxt

def graphdef_to_pbtxt(filename):
  """Convert a binary GraphDef file into ``pbtxt/protobuf.pbtxt``.

  The parsed graph is also imported into the current default graph before
  the text version is written.
  """
  with gfile.FastGFile(filename, 'rb') as pb_file:
    serialized = pb_file.read()

  parsed = tf.GraphDef()
  parsed.ParseFromString(serialized)
  tf.import_graph_def(parsed, name='')
  tf.train.write_graph(parsed, 'pbtxt/', 'protobuf.pbtxt', as_text=True)
开发者ID:chrhansen,项目名称:tensorflow.rb,代码行数:7,代码来源:converter.py


示例5: __init__

    def __init__(self):
        """Load the Inception graph-def from disk into a private graph.

        After construction ``self.graph`` holds the imported model,
        ``self.input`` is the image input tensor and ``self.layer_tensors``
        lists the output tensor of every name in ``self.layer_names``.
        """
        # The model lives in its own graph rather than the global default.
        self.graph = tf.Graph()

        with self.graph.as_default():
            # The graph-def is a binary protocol-buffer file: read the raw
            # bytes, parse them into an empty GraphDef, then import that
            # into the (currently default) graph.
            graph_def_path = os.path.join(data_dir, path_graph_def)
            with tf.gfile.FastGFile(graph_def_path, 'rb') as pb_file:
                serialized = pb_file.read()

            graph_def = tf.GraphDef()
            graph_def.ParseFromString(serialized)
            tf.import_graph_def(graph_def, name='')

            # Tensor used for feeding images into the graph.
            self.input = self.graph.get_tensor_by_name(self.tensor_name_input_image)

            # Output tensors of the commonly used layers.
            layer_tensors = []
            for layer_name in self.layer_names:
                layer_tensors.append(self.graph.get_tensor_by_name(layer_name + ":0"))
            self.layer_tensors = layer_tensors
开发者ID:Hvass-Labs,项目名称:TensorFlow-Tutorials,代码行数:34,代码来源:inception5h.py


示例6: strip_and_freeze_until

def strip_and_freeze_until(fetches, graph, sess=None, return_graph=False):
    """
    Create a static view of the graph by

    * Converting all variables into constants
    * Removing graph elements not reachable from `fetches`

    :param fetches: list, graph elements representing the outputs of the graph
    :param graph: tf.Graph, the graph to be frozen
    :param sess: optional session used for the conversion; when omitted, a
        temporary session on `graph` is created and closed before returning
    :param return_graph: bool, if set True, return a new tf.Graph with the
        frozen GraphDef imported instead of the GraphDef itself
    :return: the frozen GraphDef, or a tf.Graph built from it when
        `return_graph` is True
    """
    graph = validated_graph(graph)
    owns_session = not sess
    if owns_session:
        sess = tf.Session(graph=graph)

    try:
        gdef_frozen = tf.graph_util.convert_variables_to_constants(
            sess,
            graph.as_graph_def(add_shapes=True),
            [op_name(graph, tnsr) for tnsr in fetches])
    finally:
        # Close the session we created even when the conversion fails;
        # a caller-supplied session is left untouched.
        if owns_session:
            sess.close()

    if not return_graph:
        return gdef_frozen

    g = tf.Graph()
    with g.as_default():
        tf.import_graph_def(gdef_frozen, name='')
    return g
开发者ID:seanpquig,项目名称:spark-deep-learning,代码行数:33,代码来源:utils.py


示例7: __init__

	def __init__(self, proxy_map):
		"""Set up the periodic timer, the SIFT extractor and the TensorFlow model.

		Starts ``self.timer`` so that ``self.compute`` fires every
		``self.Period`` (the value is passed straight to ``timer.start``),
		builds a dense grid of keypoints over the image, and imports the frozen
		graph in ``MODEL_FILE`` into the session's graph.
		"""
		super(SpecificWorker, self).__init__(proxy_map)
		self.timer.timeout.connect(self.compute)
		self.Period = 100
		self.timer.start(self.Period)

		# SIFT feature extractor
		self.feature_extractor = cv2.xfeatures2d.SIFT_create()

		# Dense grid of keypoints: every 12 pixels starting at 5, size 12.
		self.keypoints = [
			cv2.KeyPoint(i, j, 12)
			for i in range(5, IMAGE_SIZE, 12)
			for j in range(5, IMAGE_SIZE, 12)
		]

		# Create a tensorflow session
		self.sess = tf.Session()

		# Read the frozen graph from the model file
		with gfile.FastGFile(MODEL_FILE, 'rb') as f:
			graph_def = tf.GraphDef()
			graph_def.ParseFromString(f.read())

		# as_default() only takes effect as a context manager; use it so the
		# import explicitly targets the session's graph.
		with self.sess.graph.as_default():
			tf.import_graph_def(graph_def, name='')

		# Get input and output tensors from graph
		self.x_input = self.sess.graph.get_tensor_by_name("input:0")
		self.output = self.sess.graph.get_tensor_by_name("output:0")
		self.dsift = self.sess.graph.get_tensor_by_name("sift:0")
开发者ID:robocomp,项目名称:robocomp-robolab,代码行数:29,代码来源:specificworker.py


示例8: Import

def Import(sess):
    """Import the frozen graph ``../models/producttype/graph.pb`` into `sess.graph`.

    :param sess: the tf.Session whose graph receives the imported nodes
    """
    with gfile.FastGFile("../models/producttype/graph.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # A bare `sess.graph.as_default()` call does nothing; entering the context
    # manager guarantees the nodes land in the session's graph even when it
    # is not the current default graph.
    with sess.graph.as_default():
        tf.import_graph_def(graph_def, name='')
开发者ID:daizhen,项目名称:ImagesCategory,代码行数:7,代码来源:import_model.py


示例9: _get_expected_result

def _get_expected_result(gin, local_features):
    """
    Run the graph held by a :py:obj:`TFInputGraph` and compute the expected results.

    :param gin: a :py:obj:`TFInputGraph`; its ``graph_def`` is imported into a fresh graph
    :param local_features: iterable of rows, each indexable by the column
        names found in ``_input_mapping``
    :return: expected results in a NumPy array (each row's outputs are
        flattened and the rows are stacked horizontally)
    """
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess, graph.as_default():
        # Build test graph and transformers from here
        tf.import_graph_def(gin.graph_def, name='')

        # The tensor lookups do not depend on the row, so resolve them once
        # instead of on every iteration.
        fetches = [tfx.get_tensor(tnsr_name, graph)
                   for tnsr_name, _ in _output_mapping.items()]
        input_tensors = [(colname, tfx.get_tensor(tnsr_name, graph))
                         for colname, tnsr_name in _input_mapping.items()]

        # Build the results
        _results = []
        for row in local_features:
            # Each column value gets a leading batch axis of size one.
            feed_dict = {tnsr: np.array(row[colname])[np.newaxis, :]
                         for colname, tnsr in input_tensors}
            curr_res = sess.run(fetches, feed_dict=feed_dict)
            _results.append(np.ravel(curr_res))

        expected = np.hstack(_results)

    return expected
开发者ID:pawanrana,项目名称:spark-deep-learning,代码行数:27,代码来源:tf_transformer_test.py


示例10: main

def main(_):
    """Classify the image named by ``sys.argv[1]`` and print a score per label.

    Labels come from ``FLAGS.output_labels`` and the model from
    ``FLAGS.output_graph``; when ``FLAGS.show_image`` is set the image is
    also displayed with matplotlib.
    """
    # Close the labels file deterministically instead of leaving it to the GC.
    with tf.gfile.GFile(FLAGS.output_labels) as label_file:
        labels = [line.rstrip() for line in label_file]

    with tf.gfile.FastGFile(FLAGS.output_graph, 'rb') as fp:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fp.read())
        tf.import_graph_def(graph_def, name='')

    # Read the raw image bytes through a context manager so the handle is
    # released (the original never closed it).
    with tf.gfile.FastGFile(sys.argv[1], 'rb') as image_file:
        image = image_file.read()

    with tf.Session() as sess:
        logits = sess.graph.get_tensor_by_name('final_result:0')
        prediction = sess.run(logits, {'DecodeJpeg/contents:0': image})

    # To report only the best match, take int(np.argmax(prediction[0])) and
    # print that single label/score pair instead of the loop below.
    print('=== 예측 결과 ===')
    for i, name in enumerate(labels):
        score = prediction[0][i]
        print('%s (%.2f%%)' % (name, score * 100))

    if FLAGS.show_image:
        img = mpimg.imread(sys.argv[1])
        plt.imshow(img)
        plt.show()
开发者ID:superhg2012,项目名称:TensorFlow-Tutorials,代码行数:29,代码来源:predict.py


示例11: classify

    def classify(self, path, resize_height, resize_width):
        """ Resizes the passed image to indicated dimensions and estimates its
            VP using the graph stored in self.filename.

            :param path: location of the image to classify
            :param resize_height: height the image is resized to before evaluation
            :param resize_width: width the image is resized to before evaluation
            :return: the model output rounded to integers
        """
        self.info("Manually classifying the image in " + str(path))
        # Load frozen graph from file.
        graph_def = tf.GraphDef()
        with open(self.filename, 'rb') as f:
            graph_def.ParseFromString(f.read())

        # Import into a dedicated graph: importing into the global default
        # graph added one more copy of the model on every call.
        graph = tf.Graph()
        with graph.as_default():
            tf.import_graph_def(graph_def)

        with tf.Session(graph=graph) as sess:
            # Load output node to use for predictions.
            output_node_processed = sess.graph.get_tensor_by_name('import/output_processed:0')
            # Read image
            img = cv.imread(path, 1)
            # Process image that will be evaluated by the model:
            # resize, convert to float32, scale by 1/256 and flatten.
            img_pred = imresize(img, [resize_height, resize_width], 'bilinear')
            img_pred = img_pred.astype(np.float32)
            img_pred = np.multiply(img_pred, 1.0 / 256.0)
            img_pred = img_pred.flatten()
            # Compute prediction point.
            predictions = output_node_processed.eval(
                feed_dict={
                    'import/input_images:0': img_pred,
                    'import/keep_prob:0': 1.0
                }
            )
            predictions = np.round(predictions).astype(int)
            self.info('Predicted Point Processed: (' + str(int(round(predictions[0][0]))) + ', ' + str(int(round(predictions[0][1]))) + ')')
        return predictions
开发者ID:se-research-studies,项目名称:2016-itsc,代码行数:34,代码来源:VPClassifier.py


示例12: __init__

    def __init__(self):
        """Obtain, unpack and load the SSD Inception detection graph.

        The weights archive is obtained through ``get_file`` and extracted
        unless the extraction directory already exists. ``self.graph`` then
        receives the frozen model, and the label-map lookups
        (``label_map``, ``categories``, ``category_index``) are built.
        """
        logger.info('Loading Tensorflow Detection API')

        weights_path = get_file(config.SSD_INCEPTION_FILENAME, config.SSD_INCEPTION_URL,
                                cache_dir=os.path.abspath(config.WEIGHT_PATH),
                                cache_subdir='models')

        extract_path = weights_path.replace('.tar.gz', '')
        if not os.path.exists(extract_path):
            # The context manager closes the archive even if extraction fails.
            # NOTE(review): extractall() trusts the member paths stored in the
            # archive; confirm the download source is trusted.
            with tarfile.open(weights_path, "r:gz") as tar:
                tar.extractall(path=os.path.join(config.WEIGHT_PATH, 'models'))
        pb_path = os.path.join(extract_path, self.PB_NAME)

        self.graph = tf.Graph()
        with self.graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(pb_path, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')

        self.label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        self.categories = label_map_util.convert_label_map_to_categories(self.label_map,
                                                                         max_num_classes=self.NUM_CLASSES,
                                                                         use_display_name=True)
        self.category_index = label_map_util.create_category_index(self.categories)
开发者ID:mohamed-akram,项目名称:pretrained.ml,代码行数:27,代码来源:models.py


示例13: build_prepro_graph

def build_prepro_graph(inception_path):
    """Import the Inception graph and add JPEG/PNG preprocessing ops.

    Sets the module-level ``input_layer`` and ``output_layer`` to tensors of
    the imported model, then builds ops that read a file, decode it as both
    JPEG and PNG, resize to 299x299, divide by 255 and reshape to
    ``[1, 299, 299, 3]``.

    Returns the file-name placeholder and the two preprocessed tensors
    (JPEG first, PNG second).
    """
    global input_layer, output_layer

    graph_def = tf.GraphDef()
    with open(inception_path, 'rb') as model_file:
        graph_def.ParseFromString(model_file.read())
    tf.import_graph_def(graph_def)
    default_graph = tf.get_default_graph()

    input_layer = default_graph.get_tensor_by_name("import/InputImage:0")
    output_layer = default_graph.get_tensor_by_name(
        "import/InceptionV4/Logits/AvgPool_1a/AvgPool:0")

    input_file = tf.placeholder(dtype=tf.string, name="InputFile")
    raw_bytes = tf.read_file(input_file)
    decoded_jpg = tf.image.decode_jpeg(raw_bytes, channels=3)
    decoded_png = tf.image.decode_png(raw_bytes, channels=3)

    batch_shape = [1, 299, 299, 3]

    scaled_jpg = tf.image.resize_images(decoded_jpg, [299, 299]) / 255.0
    output_jpg = tf.reshape(scaled_jpg, batch_shape, name="Preprocessed_JPG")

    scaled_png = tf.image.resize_images(decoded_png, [299, 299]) / 255.0
    output_png = tf.reshape(scaled_png, batch_shape, name="Preprocessed_PNG")

    return input_file, output_jpg, output_png
开发者ID:suryawanshishantanu6,项目名称:image-caption-generator,代码行数:27,代码来源:convfeatures.py


示例14: __init__

 def __init__(self, name, input):
     """Import ``models/vgg16.tfmodel`` under scope `name`, fed by `input`.

     `input` is mapped onto the model's "images" node.
     """
     self.name = name
     with open("models/vgg16.tfmodel", mode='rb') as model_file:
         serialized = model_file.read()
     vgg_graph_def = tf.GraphDef()
     vgg_graph_def.ParseFromString(serialized)
     node_replacements = {"images": input}
     tf.import_graph_def(vgg_graph_def, input_map=node_replacements, name=self.name)
开发者ID:fgeorg,项目名称:texture-networks,代码行数:7,代码来源:vgg_network.py


示例15: load_graph

def load_graph(path):
    """Return a new tf.Graph holding the GraphDef at `path`, imported under "prefix"."""
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(path, mode='rb') as pb_file:
        graph_def.ParseFromString(pb_file.read())

    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="prefix")
    return graph
开发者ID:forin-xyz,项目名称:FoolNLTK,代码行数:7,代码来源:model.py


示例16: main

def main(_):
  """Classify one hard-coded image on a remote TensorFlow server.

  Imports the frozen Inception v3 graph, runs it through the gRPC session at
  ``grpc://localhost:2222`` and prints the five highest-scoring labels.
  """
  with tf.device('/cpu:0'):
    t = read_tensor_from_image_file("/home/dek/makerfaire-booth/2018/burger/machine/data/all.299/burgers/burger_000156.png")

  graph_def = tf.GraphDef()
  # The earlier standalone `graph = tf.Graph()` was dead code: the name was
  # rebound by the `with` statement below.
  with tf.Graph().as_default() as graph:
    model_path = '/home/dek/tensorflow/tensorflow/examples/label_image/data/inception_v3_2016_08_28_frozen.pb'

    print('Model path: ', model_path)
    with open(model_path, "rb") as f:
      graph_def.ParseFromString(f.read())
    # `graph` is already the default graph here, so no nested as_default().
    tf.import_graph_def(graph_def)
    input_op = graph.get_operation_by_name('import/input')
    output_op = graph.get_operation_by_name('import/InceptionV3/Predictions/Reshape_1')
    # Use the session as a context manager so the connection is closed.
    with tf.Session("grpc://localhost:2222") as sess:
      results = sess.run(output_op.outputs[0], {
        input_op.outputs[0]: t
      })
    results = np.squeeze(results)

    # Indices of the five largest scores, highest first.
    top_k = results.argsort()[-5:][::-1]
    label_file = "/home/dek/tensorflow/tensorflow/examples/label_image/data/imagenet_slim_labels.txt"
    labels = load_labels(label_file)
    for i in top_k:
      print(labels[i], results[i])
开发者ID:google,项目名称:makerfaire-2016,代码行数:33,代码来源:client.py


示例17: __init__

    def __init__(self, model):
        """Load the frozen detection graph at `model` and open a session on it.

        Exposes the graph's ``image_tensor``, ``detection_boxes``,
        ``detection_scores``, ``detection_classes`` and ``num_detections``
        tensors as attributes of the same name, plus ``self.sess``.
        """
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(model, 'rb') as fid:
                od_graph_def.ParseFromString(fid.read())
                tf.import_graph_def(od_graph_def, name='')

        # Each attribute is named after the tensor it refers to.
        tensor_names = (
            'image_tensor',
            'detection_boxes',
            'detection_scores',
            'detection_classes',
            'num_detections',
        )
        for tensor_name in tensor_names:
            setattr(self, tensor_name,
                    detection_graph.get_tensor_by_name(tensor_name + ':0'))

        self.sess = tf.Session(graph=detection_graph)
开发者ID:bbcdli,项目名称:xuexi,代码行数:29,代码来源:detector_tfod.py


示例18: get_layer_names

def get_layer_names(model='inception'):
    """Return every layer's index and name in the given model.

    Parameters
    ----------
    model : str, optional
        Which model to load. Must be one of: ['inception'], 'i2v_tag'
        (also accepted as 'i2v-tag'), 'i2v', 'vgg16', or 'vgg_face'.

    Returns
    -------
    names : list of tuples
        The index and layer's name for every layer in the given model.

    Raises
    ------
    ValueError
        If `model` is not one of the supported names.
    """
    g = tf.Graph()
    with tf.Session(graph=g):
        if model == 'inception':
            net = inception.get_inception_model()
        elif model == 'vgg_face':
            net = vgg16.get_vgg_face_model()
        elif model == 'vgg16':
            net = vgg16.get_vgg_model()
        elif model == 'i2v':
            net = i2v.get_i2v_model()
        elif model in ('i2v-tag', 'i2v_tag'):
            # The docstring advertised 'i2v_tag' while the code only matched
            # 'i2v-tag'; accept both spellings.
            net = i2v.get_i2v_tag_model()
        else:
            # Previously an unknown name fell through and failed later with
            # an unbound-variable error on `net`.
            raise ValueError('Unknown model: %r' % (model,))

        tf.import_graph_def(net['graph_def'], name='net')
        names = [(i, op.name) for i, op in enumerate(g.get_operations())]
        return names
开发者ID:Liubinggunzu,项目名称:CADL,代码行数:30,代码来源:deepdream.py


示例19: test_i2v

def test_i2v():
    """Loads the i2v network and applies it to a test image.

    Prints the five highest-scoring (score, label) pairs, then plots one
    gradient image per pooling layer.
    """
    with tf.Session() as sess:
        net = get_i2v_model()
        tf.import_graph_def(net['graph_def'], name='i2v')
        g = tf.get_default_graph()
        names = [op.name for op in g.get_operations()]
        # The first op is used as the input, the third-from-last as the softmax.
        x = g.get_tensor_by_name(names[0] + ':0')
        softmax = g.get_tensor_by_name(names[-3] + ':0')

        from skimage import data
        img = preprocess(data.coffee())[np.newaxis]
        res = np.squeeze(softmax.eval(feed_dict={x: img}))
        print([(res[idx], net['labels'][idx])
               for idx in res.argsort()[-5:][::-1]])

        # Let's visualize the network's gradient activation
        # when backpropagated to the original input image.  This
        # is effectively telling us which pixels contribute to the
        # predicted class or given neuron
        pools = [name for name in names if 'pool' in name.split('/')[-1]]
        # squeeze=False always yields a 2-D array of axes, so indexing also
        # works when exactly one pooling layer is found.
        _, axs = plt.subplots(1, len(pools), squeeze=False)
        for pool_i, poolname in enumerate(pools):
            pool = g.get_tensor_by_name(poolname + ':0')
            neuron = tf.reduce_max(pool, 1)
            saliency = tf.gradients(neuron, x)
            neuron_idx = tf.arg_max(pool, 1)
            this_res = sess.run([saliency[0], neuron_idx],
                                feed_dict={x: img})

            # Normalise by the largest absolute gradient, then map to uint8.
            grad = this_res[0][0] / np.max(np.abs(this_res[0]))
            axs[0, pool_i].imshow((grad * 128 + 128).astype(np.uint8))
            axs[0, pool_i].set_title(poolname)
开发者ID:Arn-O,项目名称:kadenze-deep-creative-apps,代码行数:35,代码来源:i2v.py


示例20: load_graph

def load_graph(frozen_model_dir):
    """Load a frozen tensorflow graph into a new tf.Graph.

    Args:
        frozen_model_dir: directory containing the protobuf file
            ``frozen_model.pb`` that holds the frozen graph.

    Returns:
        tf.Graph object with the frozen graph definition imported.
    """

    # Parse the frozen graph definition into a GraphDef object.
    graph_def = tf.GraphDef()
    pb_path = os.path.join(frozen_model_dir, "frozen_model.pb")
    with tf.gfile.GFile(pb_path, "rb") as pb_file:
        graph_def.ParseFromString(pb_file.read())

    # Import the graph def into a fresh graph and return that graph.
    loaded_graph = tf.Graph()
    with loaded_graph.as_default():
        tf.import_graph_def(graph_def,
                            input_map=None,
                            return_elements=None,
                            op_dict=None,
                            producer_op_list=None)
    return loaded_graph
开发者ID:laurii,项目名称:DeepChatModels,代码行数:25,代码来源:web_bot.py



注:本文中的tensorflow.import_graph_def函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python tensorflow.initialize_all_tables函数代码示例发布时间:2022-05-27
下一篇:
Python tensorflow.image_summary函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap