Python tensorflow.local_variables_initializer Function Code Examples


This article collects typical usage examples of the tensorflow.local_variables_initializer function in Python. If you are wondering what exactly local_variables_initializer does, how it is used, and what real-world calls look like, the curated code examples below should help.



The following presents 20 code examples of the local_variables_initializer function, sorted by popularity by default.
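Before turning to the examples, it helps to recall what this initializer is for: variables placed in the tf.GraphKeys.LOCAL_VARIABLES collection, such as the internal counters created by tf.metrics.* ops or by input queues with a num_epochs limit, are not covered by tf.global_variables_initializer(), so a session must run tf.local_variables_initializer() separately before using them. A minimal sketch of the pattern, assuming TensorFlow 1.x graph mode:

import tensorflow as tf

labels = tf.constant([1, 0, 1, 1])
predictions = tf.constant([1, 0, 0, 1])
# tf.metrics.accuracy keeps its running total/count in *local* variables.
accuracy, update_op = tf.metrics.accuracy(labels, predictions)

with tf.Session() as sess:
    # Initialize both collections; global_variables_initializer alone would
    # leave the metric's counters uninitialized (FailedPreconditionError).
    sess.run([tf.global_variables_initializer(),
              tf.local_variables_initializer()])
    sess.run(update_op)        # accumulate statistics for one batch
    print(sess.run(accuracy))  # 0.75

Most of the examples below follow this same pattern, often combining the two initializers with tf.group before starting queue runners or metric updates.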

Example 1: test

def test(model, config, prompts):

    sr = 24000 if 'blizzard' in config.data_path else 16000
    meta = data_input.load_meta(config.data_path)
    config.r = audio.r
    ivocab = meta['vocab']
    config.vocab_size = len(ivocab)

    with tf.device('/cpu:0'):
        batch_inputs = data_input.load_prompts(prompts, ivocab)
        config.num_prompts = len(prompts)

    with tf.Session() as sess:

        # spectrogram normalization statistics, restored from the checkpoint below
        stft_mean = tf.get_variable('stft_mean', shape=(1025*audio.r,), dtype=tf.float32)
        stft_std = tf.get_variable('stft_std', shape=(1025*audio.r,), dtype=tf.float32)

        # initialize model
        model = model(config, batch_inputs, train=False)

        train_writer = tf.summary.FileWriter('log/' + config.save_path + '/test', sess.graph)

        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        saver = tf.train.Saver()

        print('restoring weights')
        latest_ckpt = tf.train.latest_checkpoint(
            'weights/' + config.save_path[:config.save_path.rfind('/')]
        )
        saver.restore(sess, latest_ckpt)

        stft_mean, stft_std = sess.run([stft_mean, stft_std])

        try:
            while True:
                out = sess.run([
                    model.output,
                    model.alignments,
                    batch_inputs
                ])
                outputs, alignments, inputs = out

                print('saving samples')
                for out, words, align in zip(outputs, inputs['text'], alignments):
                    # store a sample to listen to
                    text = ''.join([ivocab[w] for w in words])
                    attention_plot = data_input.generate_attention_plot(align)
                    sample = audio.invert_spectrogram(out*stft_std + stft_mean)
                    merged = sess.run(tf.summary.merge(
                         [tf.summary.audio(text, sample[None, :], sr),
                          tf.summary.image(text, attention_plot)]
                    ))
                    train_writer.add_summary(merged, 0)
        except tf.errors.OutOfRangeError:
            coord.request_stop()
            coord.join(threads)
Developer ID: yhgon, Project: Tacotron-tf-barronalex, Lines of code: 60, Source: test.py


Example 2: train

def train(model, data, gen, params):
    anim_frames = []

    with tf.Session() as session:
        tf.local_variables_initializer().run()
        tf.global_variables_initializer().run()

        for step in range(params.num_steps + 1):
            # update discriminator
            x = data.sample(params.batch_size)
            z = gen.sample(params.batch_size)
            loss_d, _ = session.run([model.loss_d, model.opt_d], {
                model.x: np.reshape(x, (params.batch_size, 1)),
                model.z: np.reshape(z, (params.batch_size, 1))
            })

            # update generator
            z = gen.sample(params.batch_size)
            loss_g, _ = session.run([model.loss_g, model.opt_g], {
                model.z: np.reshape(z, (params.batch_size, 1))
            })

            if step % params.log_every == 0:
                print('{}: {:.4f}\t{:.4f}'.format(step, loss_d, loss_g))

            if params.anim_path and (step % params.anim_every == 0):
                anim_frames.append(
                    samples(model, session, data, gen.range, params.batch_size)
                )

        if params.anim_path:
            save_animation(anim_frames, params.anim_path, gen.range)
        else:
            samps = samples(model, session, data, gen.range, params.batch_size)
            plot_distributions(samps, gen.range)
Developer ID: yashvardhan90, Project: gan-intro, Lines of code: 35, Source: gan.py


Example 3: test_empty_labels_and_scores_gives_nan_auc

  def test_empty_labels_and_scores_gives_nan_auc(self):
    with self.test_session():
      labels = tf.constant([], shape=[0], dtype=tf.bool)
      scores = tf.constant([], shape=[0], dtype=tf.float32)
      score_range = [0, 1.]
      auc, update_op = tf.contrib.metrics.auc_using_histogram(
          labels, scores, score_range)
      tf.local_variables_initializer().run()
      update_op.run()
      self.assertTrue(np.isnan(auc.eval()))
Developer ID: ComeOnGetMe, Project: tensorflow, Lines of code: 10, Source: histogram_ops_test.py


Example 4: _check_auc

  def _check_auc(self,
                 nbins=100,
                 desired_auc=0.75,
                 score_range=None,
                 num_records=50,
                 frac_true=0.5,
                 atol=0.05,
                 num_updates=10):
    """Check auc accuracy against synthetic data.

    Args:
      nbins:  nbins arg from contrib.metrics.auc_using_histogram.
      desired_auc:  Number in [0, 1].  The desired auc for synthetic data.
      score_range:  2-tuple, (low, high), giving the range of the resultant
        scores.  Defaults to [0, 1.].
      num_records:  Positive integer.  The number of records to return.
      frac_true:  Number in (0, 1).  Expected fraction of resultant labels that
        will be True.  This is just in expectation...more or less may actually
        be True.
      atol:  Absolute tolerance for final AUC estimate.
      num_updates:  Update internal histograms this many times, each with a new
        batch of synthetic data, before computing final AUC.

    Raises:
      AssertionError: If resultant AUC is not within atol of theoretical AUC
        from synthetic data.
    """
    score_range = score_range or [0, 1.]  # fall back to the documented default
    with self.test_session():
      labels = tf.placeholder(tf.bool, shape=[num_records])
      scores = tf.placeholder(tf.float32, shape=[num_records])
      auc, update_op = tf.contrib.metrics.auc_using_histogram(labels,
                                                              scores,
                                                              score_range,
                                                              nbins=nbins)
      tf.local_variables_initializer().run()
      # Updates, then extract auc.
      for _ in range(num_updates):
        labels_a, scores_a = synthetic_data(desired_auc, score_range,
                                            num_records, self.rng, frac_true)
        update_op.run(feed_dict={labels: labels_a, scores: scores_a})
      labels_a, scores_a = synthetic_data(desired_auc, score_range, num_records,
                                          self.rng, frac_true)
      # Fetch current auc, and verify that fetching again doesn't change it.
      auc_eval = auc.eval()
      self.assertAlmostEqual(auc_eval, auc.eval(), places=5)

    msg = ('nbins: %s, desired_auc: %s, score_range: %s, '
           'num_records: %s, frac_true: %s, num_updates: %s') % (nbins,
                                                                 desired_auc,
                                                                 score_range,
                                                                 num_records,
                                                                 frac_true,
                                                                 num_updates)
    np.testing.assert_allclose(desired_auc, auc_eval, atol=atol, err_msg=msg)
Developer ID: ComeOnGetMe, Project: tensorflow, Lines of code: 55, Source: histogram_ops_test.py


Example 5: train

    def train(self, DGTrain, DGTest, saver=True):

        epoch = DGTrain.length

        self.LearningRateSchedule(self.LEARNING_RATE, self.K, epoch)

        trainable_var = tf.trainable_variables()
        
        self.regularize_model()
        self.optimization(trainable_var)
        self.ExponentialMovingAverage(trainable_var, self.DECAY_EMA)

        tf.global_variables_initializer().run()
        tf.local_variables_initializer().run()

        self.summary_test_writer = tf.summary.FileWriter(self.LOG + '/test',
                                            graph=self.sess.graph)

        self.summary_writer = tf.summary.FileWriter(self.LOG + '/train', graph=self.sess.graph)
        merged_summary = tf.summary.merge_all()
        steps = self.STEPS

        
        # for i in range(Xval.shape[0]):
        #     imsave("/tmp/image_{}.png".format(i), Xval[i])
        #     imsave("/tmp/label_{}.png".format(i), Yval[i,:,:,0])



        for step in range(steps):
            batch_data, batch_labels = DGTrain.Batch(0, self.BATCH_SIZE)
            feed_dict = {self.input_node: batch_data,
                         self.train_labels_node: batch_labels}

            # self.optimizer is replaced by self.training_op for the exponential moving decay
            _, l, lr, predictions, s = self.sess.run(
                        [self.training_op, self.loss, self.learning_rate,
                         self.train_prediction, merged_summary],
                        feed_dict=feed_dict)

            if step % self.N_PRINT == 0:
                i = datetime.now()
                print(i.strftime('%Y/%m/%d %H:%M:%S: \n '))
                self.summary_writer.add_summary(s, step)                
                error, acc, acc1, recall, prec, f1 = self.error_rate(predictions, batch_labels, step)
                print('  Step %d of %d' % (step, steps))
                print('  Learning rate: %.5f \n' % lr)
                print('  Mini-batch loss: %.5f \n       Accuracy: %.1f%% \n       acc1: %.1f%% \n       recall: %1.f%% \n       prec: %1.f%% \n       f1 : %1.f%% \n' % 
                      (l, acc, acc1, recall, prec, f1))
                self.Validation(DGTest, step)
Developer ID: PeterJackNaylor, Project: PhD_Fabien, Lines of code: 50, Source: ObjectOriented.py


Example 6: main

def main(model_config, train_config, track_config):
  # Create training directory
  train_dir = train_config['train_dir']
  if not tf.gfile.IsDirectory(train_dir):
    tf.logging.info('Creating training directory: %s', train_dir)
    tf.gfile.MakeDirs(train_dir)

  # Build the Tensorflow graph
  g = tf.Graph()
  with g.as_default():
    # Set fixed seed
    np.random.seed(train_config['seed'])
    tf.set_random_seed(train_config['seed'])

    # Build the model
    model = siamese_model.SiameseModel(model_config, train_config, mode='inference')
    model.build()

    # Save configurations for future reference
    save_cfgs(train_dir, model_config, train_config, track_config)

    saver = tf.train.Saver(tf.global_variables(),
                           max_to_keep=train_config['max_checkpoints_to_keep'])

    # Dynamically allocate GPU memory
    gpu_options = tf.GPUOptions(allow_growth=True)
    sess_config = tf.ConfigProto(gpu_options=gpu_options)

    sess = tf.Session(config=sess_config)
    model_path = tf.train.latest_checkpoint(train_config['train_dir'])

    if not model_path:
      # Initialize all variables
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      start_step = 0

      # Load pretrained embedding model if needed
      if model_config['embed_config']['embedding_checkpoint_file']:
        model.init_fn(sess)

    else:
      logging.info('Restore from last checkpoint: {}'.format(model_path))
      sess.run(tf.local_variables_initializer())
      saver.restore(sess, model_path)
      start_step = tf.train.global_step(sess, model.global_step.name) + 1

    checkpoint_path = osp.join(train_config['train_dir'], 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=start_step)
Developer ID: fossabot, Project: SiamFC-TensorFlow, Lines of code: 49, Source: convert_pretrained_model.py


Example 7: test_batch_text_lines

  def test_batch_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("A\nB\nC\nD\nE\n")

    batch_size = 3
    queue_capacity = 10
    name = "my_batch"

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      inputs = tf.contrib.learn.io.read_batch_examples(
          [filename], batch_size, reader=tf.TextLineReader,
          randomize_input=False, num_epochs=1, queue_capacity=queue_capacity,
          read_batch_size=10, name=name)
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(tf.local_variables_initializer())

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run(inputs), [b"A", b"B", b"C"])
      self.assertAllEqual(session.run(inputs), [b"D", b"E"])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
Developer ID: moolighty, Project: tensorflow, Lines of code: 26, Source: graph_io_test.py


Example 8: testRoundtrip

    def testRoundtrip(self, rate=0.25, count=5, n=500):
        """Tests `resample(x, weights)` and resample(resample(x, rate), 1/rate)`."""

        foo = self.get_values(count)
        bar = self.get_values(count)
        weights = self.get_weights(count)

        resampled_in, rates = tf.contrib.training.weighted_resample([foo, bar], tf.constant(weights), rate, seed=123)

        resampled_back_out = tf.contrib.training.resample_at_rate(resampled_in, 1.0 / rates, seed=456)

        init = tf.local_variables_initializer()
        with self.test_session() as s:
            s.run(init)  # initialize

            # outputs
            counts_resampled = collections.Counter()
            counts_reresampled = collections.Counter()
            for _ in range(n):
                resampled_vs, reresampled_vs = s.run([resampled_in, resampled_back_out])

                self.assertAllEqual(resampled_vs[0], resampled_vs[1])
                self.assertAllEqual(reresampled_vs[0], reresampled_vs[1])

                for v in resampled_vs[0]:
                    counts_resampled[v] += 1
                for v in reresampled_vs[0]:
                    counts_reresampled[v] += 1

            # assert that resampling worked as expected
            self.assert_expected(weights, rate, counts_resampled, n)

            # and that re-resampling gives the approx identity.
            self.assert_expected([1.0 for _ in weights], 1.0, counts_reresampled, n, abs_delta=0.1 * n * count)
Developer ID: brchiu, Project: tensorflow, Lines of code: 34, Source: resample_test.py


Example 9: blend_images

def blend_images(data_folder1, data_folder2, out_folder, alpha=.5):
    filename_queue = tf.placeholder(dtype=tf.string)
    label = tf.placeholder(dtype=tf.int32)
    tensor_image = tf.read_file(filename_queue)

    image = tf.image.decode_jpeg(tensor_image, channels=3)

    multiplier = tf.div(tf.constant(224, tf.float32),
                        tf.cast(tf.maximum(tf.shape(image)[0], tf.shape(image)[1]), tf.float32))
    x = tf.cast(tf.round(tf.mul(tf.cast(tf.shape(image)[0], tf.float32), multiplier)), tf.int32)
    y = tf.cast(tf.round(tf.mul(tf.cast(tf.shape(image)[1], tf.float32), multiplier)), tf.int32)
    image = tf.image.resize_images(image, [x, y])

    image = tf.image.rot90(image, k=label)

    image = tf.image.resize_image_with_crop_or_pad(image, 224, 224)
    sess = tf.Session()
    sess.run(tf.local_variables_initializer())
    for root, folders, files in os.walk(data_folder1):
        for each in files:
            if each.find('.jpg') >= 0:
                img1 = Image.open(os.path.join(root, each))
                img2_path = os.path.join(root.replace(data_folder1, data_folder2), each.split("-")[-1])
                rotation = int(each.split("-")[1])
                img2 = sess.run(image, feed_dict={filename_queue: img2_path, label: rotation})
                imsave(os.path.join(os.getcwd(), "temp", "temp.jpg"), img2)
                img2 = Image.open(os.path.join(os.getcwd(), "temp", "temp.jpg"))
                out_image = Image.blend(img1, img2, alpha)
                outfile = os.path.join(root.replace(data_folder1, out_folder), each)
                if not os.path.exists(os.path.split(outfile)[0]):
                    os.makedirs(os.path.split(outfile)[0])
                out_image.save(outfile)
            else:
                print(each)
    sess.close()
Developer ID: Sabrewarrior, Project: PhotoOrientation, Lines of code: 35, Source: misc.py


Example 10: test

    def test(self, p1, p2, steps):
        loss, roc = 0., 0.
        acc, F1, recall = 0., 0., 0.
        precision, jac, AJI = 0., 0., 0.
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        self.sess.run(init_op)
        self.Saver()
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for step in range(steps):  
            feed_dict = {self.is_training: False} 
            l, prob, batch_labels = self.sess.run(
                [self.loss, self.train_prediction, self.train_labels_node],
                feed_dict=feed_dict)
            loss += l
            out = ComputeMetrics(prob[0,:,:,1], batch_labels[0,:,:,0], p1, p2)
            acc += out[0]
            roc += out[1]
            jac += out[2]
            recall += out[3]
            precision += out[4]
            F1 += out[5]
            AJI += out[6]
        coord.request_stop()
        coord.join(threads)
        loss, acc, F1 = np.array([loss, acc, F1]) / steps
        recall, precision, roc = np.array([recall, precision, roc]) / steps
        jac, AJI = np.array([jac, AJI]) / steps
        return loss, acc, F1, recall, precision, roc, jac, AJI
Developer ID: PeterJackNaylor, Project: PhD_Fabien, Lines of code: 30, Source: UNet.py


Example 11: initialize_variables

def initialize_variables(sess, saver, logdir, checkpoint=None, resume=None):
  """Initialize or restore variables from a checkpoint if available.

  Args:
    sess: Session to initialize variables in.
    saver: Saver to restore variables.
    logdir: Directory to search for checkpoints.
    checkpoint: Specify what checkpoint name to use; defaults to most recent.
    resume: Whether to expect recovering a checkpoint or starting a new run.

  Raises:
    ValueError: If resume expected but no log directory specified.
    RuntimeError: If no resume expected but a checkpoint was found.
  """
  sess.run(tf.group(
      tf.local_variables_initializer(),
      tf.global_variables_initializer()))
  if resume and not (logdir or checkpoint):
    raise ValueError('Need to specify logdir to resume a checkpoint.')
  if logdir:
    state = tf.train.get_checkpoint_state(logdir)
    if checkpoint:
      checkpoint = os.path.join(logdir, checkpoint)
    if not checkpoint and state and state.model_checkpoint_path:
      checkpoint = state.model_checkpoint_path
    if checkpoint and resume is False:
      message = 'Found unexpected checkpoint when starting a new run.'
      raise RuntimeError(message)
    if checkpoint:
      saver.restore(sess, checkpoint)
Developer ID: shamanez, Project: agents, Lines of code: 30, Source: utility.py


Example 12: testSummariesAreFlushedToDiskWithoutGlobalStep

  def testSummariesAreFlushedToDiskWithoutGlobalStep(self):
    output_dir = os.path.join(self.get_temp_dir(), 'flush_test_no_global_step')
    if tf.gfile.Exists(output_dir):  # For running on jenkins.
      tf.gfile.DeleteRecursively(output_dir)

    names_to_metrics, names_to_updates = self._create_names_to_metrics(
        self._predictions, self._labels)

    for k in names_to_metrics:
      v = names_to_metrics[k]
      tf.summary.scalar(k, v)

    summary_writer = tf.train.SummaryWriter(output_dir)

    initial_op = tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer())
    eval_op = tf.group(*names_to_updates.values())

    with self.test_session() as sess:
      slim.evaluation.evaluation(
          sess,
          initial_op=initial_op,
          eval_op=eval_op,
          summary_op=tf.summary.merge_all(),
          summary_writer=summary_writer)

      names_to_values = {name: names_to_metrics[name].eval()
                         for name in names_to_metrics}
    self._verify_summaries(output_dir, names_to_values)
Developer ID: ComeOnGetMe, Project: tensorflow, Lines of code: 29, Source: evaluation_test.py


Example 13: run

def run():
    with tf.Session() as sess:
        print("start")
        feature = {'image': tf.FixedLenFeature([], tf.string),
                   'label': tf.FixedLenFeature([], tf.int64)}
        # Create a list of filenames and pass it to a queue
        print(data_path)
        filename_queue = tf.train.string_input_producer(data_path, num_epochs=1)
        # Define a reader and read the next record
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        # Decode the record read by the reader
        features = tf.parse_single_example(serialized_example, features=feature)
        # Convert the image data from string back to the numbers
        image = tf.decode_raw(features['image'], tf.uint8)
        # image = tf.cast(image, tf.int32)

        # Cast label data into int32
        label = tf.cast(features['label'], tf.int32)
        # Reshape image data into the original shape
        init_op = [tf.global_variables_initializer(), tf.local_variables_initializer()]
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        train_list = []
        for i in range(1000):
            example, l = sess.run([image, label])
            train_list.append((example,l))
            # print (example, l)
        coord.request_stop()
        coord.join(threads)
        return train_list
# run()
Developer ID: ykakde, Project: trash-classifier, Lines of code: 34, Source: tf_file_reader.py


Example 14: main

def main(argv):
  del argv  # Unused.
  # Sanity check on the GCS bucket URL.
  if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
    print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
    sys.exit(1)

  # Verify that writing to the records file in GCS works.
  print("\n=== Testing writing and reading of GCS record file... ===")
  example_data = create_examples(FLAGS.num_examples, 5)
  with tf.python_io.TFRecordWriter(FLAGS.gcs_bucket_url) as hf:
    for e in example_data:
      hf.write(e.SerializeToString())

    print("Data written to: %s" % FLAGS.gcs_bucket_url)

  # Verify that reading from the tfrecord file works and that
  # tf_record_iterator works.
  record_iter = tf.python_io.tf_record_iterator(FLAGS.gcs_bucket_url)
  read_count = 0
  for _ in record_iter:
    read_count += 1
  print("Read %d records using tf_record_iterator" % read_count)

  if read_count != FLAGS.num_examples:
    print("FAIL: The number of records read from tf_record_iterator (%d) "
          "differs from the expected number (%d)" % (read_count,
                                                     FLAGS.num_examples))
    sys.exit(1)

  # Verify that running the read op in a session works.
  print("\n=== Testing TFRecordReader.read op in a session... ===")
  with tf.Graph().as_default() as _:
    filename_queue = tf.train.string_input_producer([FLAGS.gcs_bucket_url],
                                                    num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      tf.train.start_queue_runners()
      index = 0
      for _ in range(FLAGS.num_examples):
        print("Read record: %d" % index)
        sess.run(serialized_example)
        index += 1

      # Reading one more record should trigger an exception.
      try:
        sess.run(serialized_example)
        print("FAIL: Failed to catch the expected OutOfRangeError while "
              "reading one more record than is available")
        sys.exit(1)
      except tf.errors.OutOfRangeError:
        print("Successfully caught the expected OutOfRangeError while "
              "reading one more record than is available")

  create_dir_test()
  create_object_test()
Developer ID: DILASSS, Project: tensorflow, Lines of code: 60, Source: gcs_smoke.py


Example 15: predict

	def predict(self):
		import cv2
		import glob
		import numpy as np
		# TODO: this should not be written this way -- read and predict the images directly instead of from the tfrecord, because the record order differs and predictions can no longer be matched to their inputs
		predict_file_path = glob.glob(os.path.join(ORIGIN_PREDICT_DIRECTORY, '*.tif'))
		print(len(predict_file_path))
		ckpt_path = CHECK_POINT_PATH
		all_parameters_saver = tf.train.Saver()
		with tf.Session() as sess:  # start a session
			sess.run(tf.global_variables_initializer())
			sess.run(tf.local_variables_initializer())
			# summary_writer = tf.summary.FileWriter(FLAGS.tb_dir, sess.graph)
			# tf.summary.FileWriter(FLAGS.model_dir, sess.graph)
			all_parameters_saver.restore(sess=sess, save_path=ckpt_path)
			for index, image_path in enumerate(predict_file_path):
				# image = cv2.imread(image_path, flags=0)
				image = np.reshape(a=cv2.imread(image_path, flags=0), newshape=(1, INPUT_IMG_WIDE, INPUT_IMG_HEIGHT, INPUT_IMG_CHANNEL))
				predict_image = sess.run(
					tf.argmax(input=self.prediction, axis=3),
					feed_dict={
						self.input_image: image,
						self.keep_prob: 1.0, self.lamb: 0.004
					}
				)
				cv2.imwrite(os.path.join(PREDICT_SAVED_DIRECTORY, '%d.jpg' % index), predict_image[0] * 255)
		print('Done prediction')
Developer ID: USTCzxm, Project: U-net, Lines of code: 27, Source: unet-TF.py


Example 16: test_input_pipeline

    def test_input_pipeline(self):
        Xs, Ys = dsu.tiny_imagenet_load()
        n_batches = 0
        batch_size = 10
        with tf.Graph().as_default(), tf.Session() as sess:
            batch_generator = dsu.create_input_pipeline(
                Xs[:100],
                batch_size=batch_size,
                n_epochs=1,
                shape=(64, 64, 3),
                crop_shape=(64, 64, 3))
            init_op = tf.group(tf.global_variables_initializer(),
                               tf.local_variables_initializer())
            sess.run(init_op)
            coord = tf.train.Coordinator()
            tf.get_default_graph().finalize()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)
            try:
                while not coord.should_stop():
                    batch = sess.run(batch_generator)
                    assert (batch.shape == (batch_size, 64, 64, 3))
                    n_batches += 1
            except tf.errors.OutOfRangeError:
                pass
            finally:
                coord.request_stop()
            coord.join(threads)
        assert (n_batches == 10)
Developer ID: pradeeps, Project: pycadl, Lines of code: 28, Source: test_dataset_utils.py


Example 17: get_hit_rate_and_ndcg

  def get_hit_rate_and_ndcg(self, predicted_scores_by_user, items_by_user,
                            top_k=rconst.TOP_K, match_mlperf=False):
    rconst.TOP_K = top_k
    rconst.NUM_EVAL_NEGATIVES = predicted_scores_by_user.shape[1] - 1
    batch_size = items_by_user.shape[0]

    users = np.repeat(np.arange(batch_size)[:, np.newaxis],
                      rconst.NUM_EVAL_NEGATIVES + 1, axis=1)
    users, items, duplicate_mask = \
      data_pipeline.BaseDataConstructor._assemble_eval_batch(
          users, items_by_user[:, -1:], items_by_user[:, :-1], batch_size)

    g = tf.Graph()
    with g.as_default():
      logits = tf.convert_to_tensor(
          predicted_scores_by_user.reshape((-1, 1)), tf.float32)
      softmax_logits = tf.concat([tf.zeros(logits.shape, dtype=logits.dtype),
                                  logits], axis=1)
      duplicate_mask = tf.convert_to_tensor(duplicate_mask, tf.float32)

      metric_ops = neumf_model.compute_eval_loss_and_metrics(
          logits=logits, softmax_logits=softmax_logits,
          duplicate_mask=duplicate_mask, num_training_neg=NUM_TRAIN_NEG,
          match_mlperf=match_mlperf).eval_metric_ops

      hr = metric_ops[rconst.HR_KEY]
      ndcg = metric_ops[rconst.NDCG_KEY]

      init = [tf.global_variables_initializer(),
              tf.local_variables_initializer()]

    with self.test_session(graph=g) as sess:
      sess.run(init)
      return sess.run([hr[1], ndcg[1]])
Developer ID: pooyadavoodi, Project: models, Lines of code: 34, Source: ncf_test.py


Example 18: test_smoke

    def test_smoke(self):
        """Smoke test for a full pipeline."""
        _, tname = tempfile.mkstemp()
        num = 100
        num_epochs = 2
        self._write_examples(tname, [self._random_io_data() for _ in range(num)])
        tensors = data.read_from_files([tname], shuffle=True, num_epochs=num_epochs)
        batches = lin.shuffle_batch(tensors=tensors, batch_size=5)

        count = 0
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                while True:
                    actual = sess.run(batches)
                    count += len(actual[0])
            except tf.errors.OutOfRangeError as ex:
                coord.request_stop(ex=ex)
            finally:
                coord.request_stop()
                coord.join(threads)
        self.assertEqual(num * num_epochs, count)
        os.remove(tname)
Developer ID: usman776, Project: dket, Lines of code: 26, Source: test_data.py


Example 19: test_keyed_read_text_lines

  def test_keyed_read_text_lines(self):
    gfile.Glob = self._orig_glob
    filename = self._create_temp_file("ABC\nDEF\nGHK\n")

    batch_size = 1
    queue_capacity = 5
    name = "my_batch"

    with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
      keys, inputs = tf.contrib.learn.io.read_keyed_batch_examples(
          filename, batch_size,
          reader=tf.TextLineReader, randomize_input=False,
          num_epochs=1, queue_capacity=queue_capacity, name=name)
      self.assertAllEqual((None,), keys.get_shape().as_list())
      self.assertAllEqual((None,), inputs.get_shape().as_list())
      session.run(tf.local_variables_initializer())

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(session, coord=coord)

      self.assertAllEqual(session.run([keys, inputs]),
                          [[filename.encode("utf-8") + b":1"], [b"ABC"]])
      self.assertAllEqual(session.run([keys, inputs]),
                          [[filename.encode("utf-8") + b":2"], [b"DEF"]])
      self.assertAllEqual(session.run([keys, inputs]),
                          [[filename.encode("utf-8") + b":3"], [b"GHK"]])
      with self.assertRaises(errors.OutOfRangeError):
        session.run(inputs)

      coord.request_stop()
      coord.join(threads)
Developer ID: moolighty, Project: tensorflow, Lines of code: 31, Source: graph_io_test.py


Example 20: predict

	def predict(self):
		print('Running inference...')
		self.sess.run(tf.group(tf.global_variables_initializer(),tf.local_variables_initializer()))
		self.load_weights('/Users/shashank/TensorFlow/SPN/weights/')
		coord = tf.train.Coordinator()
		threads = tf.train.start_queue_runners(sess=self.sess,coord=coord)

		result = []
		truth = []
		count = 0
		try:
			while not coord.should_stop():
				print(count)
				batch_imgs, batch_labels, batch_landmarks, batch_visibility, batch_pose, batch_gender = self.sess.run([self.images,self.labels,self.land, self.vis, self.po, self.gen])
				batch_imgs = (batch_imgs - 127.5) / 128.0
				
				net_preds = self.sess.run(self.net_output, feed_dict={self.X: batch_imgs})
				result.append(np.concatenate(net_preds, axis=1))
				truth.append(np.concatenate([batch_labels[:, np.newaxis], batch_landmarks, batch_visibility, batch_pose, batch_gender], axis=1))
				count += 1

		except tf.errors.OutOfRangeError:
			print('Done training -- epoch limit reached')
		finally:
			coord.request_stop()

		coord.join(threads)	
		np.save('test_results', np.concatenate(result, axis = 0))
		np.save('truth', np.concatenate(truth, axis = 0))
Developer ID: dmehr, Project: HyperFace-TensorFlow-implementation, Lines of code: 29, Source: model_prediction.py



Note: The tensorflow.local_variables_initializer examples in this article were compiled by 纯净天空 from source-code and documentation platforms such as GitHub and MSDocs. The snippets are selected from open-source projects contributed by many developers; copyright remains with the original authors. Consult each project's License before distributing or reusing the code, and do not reproduce this compilation without permission.

