本文整理汇总了Python中numpy.stack函数的典型用法代码示例。如果您正苦于以下问题：Python stack函数的具体用法？Python stack怎么用？Python stack使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了stack函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: arrow3d
def arrow3d(base, r1, r2, ort, l, h, m = 13, pivot = 'tail'):
    """Build the triangle mesh of a 3-D arrow (cylindrical shaft + conical head).

    Parameters
    ----------
    base : array-like, shape (3,)
        Anchor point of the arrow; which part of the arrow sits here is
        selected by ``pivot``.
    r1 : float
        Shaft radius as a fraction of ``l`` (actual radius ``l * r1``).
    r2 : float
        Head base radius as a fraction of ``l`` (actual radius ``l * r2``).
    ort : array-like, shape (3,)
        Arrow direction; normalised with ``norm_vec``.
    l : float
        Shaft length.
    h : float
        Head length.
    m : int
        Number of angles sampled around each ring.  The first and last
        angles coincide, so each ring has ``m - 1`` facets.
    pivot : {'tail', 'mid', 'tip'}
        Which part of the arrow is placed at ``base``.  Any other value is
        treated like 'tail'.

    Returns
    -------
    numpy.ndarray, shape (3 * (m - 1), 3, 3)
        Triangles (shaft side, then head cone), each as three 3-D vertices.
    """
    x = np.array([1., 0., 0.])
    y = np.array([0., 1., 0.])
    th = np.linspace(0, np.pi*2, m).reshape(-1, 1)
    ort = norm_vec(ort)
    # d1 must be perpendicular to ort, so it is built as a cross product with
    # a reference axis.  That product vanishes when ort is parallel to the
    # chosen axis (ort along x, or ort along y when ort.x == 0), and
    # normalising a zero vector is ill-defined.  Fall back to the other axis
    # in that case; ort cannot be parallel to both.
    ref, alt = (y, x) if np.sum(ort * x) == 0 else (x, y)
    d1 = np.cross(ort, ref)
    if np.allclose(d1, 0):
        d1 = np.cross(ort, alt)
    d1 = norm_vec(d1)
    if pivot == 'tip':
        base = base - (l+h)*ort
    elif pivot == 'mid':
        base = base - (l+h)*ort/2.
    d2 = np.cross(ort, d1)
    # Points of a circle in the plane spanned by d1/d2, shape (m, 3).
    ring = d1*np.cos(th) + d2*np.sin(th)
    p = base + l*r1*ring              # shaft ring at the tail end
    q = p + l*ort                     # shaft ring at the head end
    p2 = base + l*r2*ring + l*ort     # ring at the base of the head cone
    p3 = base + (l+h)*ort             # tip of the head
    p3 = np.array([p3]*m).reshape(-1, 3)
    # Two triangles per shaft facet, one per cone facet.
    t1 = np.stack((p[:-1], q[:-1], p[1:]), axis=1)
    t2 = np.stack((p[1:], q[:-1], q[1:]), axis=1)
    t3 = np.stack((p2[:-1], p3[:-1], p2[1:]), axis=1)
    return np.vstack((t1, t2, t3))
开发者ID:piScope,项目名称:piScope,代码行数:27,代码来源:axes3d_mod.py
示例2: check_rnn_forward
def check_rnn_forward(layer, inputs, deterministic=True):
    """Compare an RNN layer's imperative and hybridized forward/backward passes.

    ``inputs`` is either a single ``mx.nd.NDArray`` or an iterable of them.
    Gradients are attached to the inputs and the layer's parameters are
    initialised.  The layer is then unrolled for 3 steps (first without, then
    with merged outputs) and back-propagated, once before and once after
    ``layer.hybridize()``.  When ``deterministic`` is true, the merged
    outputs and input gradients of the two runs must agree within
    rtol=1e-3 / atol=1e-5.
    """
    single = isinstance(inputs, mx.nd.NDArray)

    def collect_input_grads():
        # One numpy array of input gradients; for a list of inputs the
        # per-input gradients are stacked along axis 1.
        if single:
            return inputs.grad.asnumpy()
        return np.stack([x.grad.asnumpy() for x in inputs], axis=1)

    def run_forward_backward():
        # Unroll with and without merged outputs, back-propagating each;
        # return the merged-output result.
        with mx.autograd.record():
            per_step = layer.unroll(3, inputs, merge_outputs=False)[0]
            mx.autograd.backward(per_step)
            merged = layer.unroll(3, inputs, merge_outputs=True)[0]
            merged.backward()
        return merged

    if single:
        inputs.attach_grad()
    else:
        for x in inputs:
            x.attach_grad()
    layer.collect_params().initialize()

    imperative_out = run_forward_backward().asnumpy()
    imperative_dx = collect_input_grads()

    layer.hybridize()
    hybrid_out = run_forward_backward()
    hybrid_dx = collect_input_grads()

    if deterministic:
        mx.test_utils.assert_almost_equal(imperative_out, hybrid_out.asnumpy(), rtol=1e-3, atol=1e-5)
        mx.test_utils.assert_almost_equal(imperative_dx, hybrid_dx, rtol=1e-3, atol=1e-5)
开发者ID:dpom,项目名称:incubator-mxnet,代码行数:35,代码来源:test_gluon_rnn.py
示例3: step
def step(self, action):
    """Forward a batch of actions to the wrapped environments.

    Every action is validated against its environment's ``action_space``
    before any environment is stepped.  In non-blocking mode each
    environment's ``step(..., blocking=False)`` returns a callable, and all
    of these are resolved after every environment has been started.

    Args:
      action: Batched action to apply to the environment.

    Raises:
      ValueError: Invalid actions.

    Returns:
      Tuple ``(observ, reward, done, info)``: the first three are stacked
      numpy arrays; ``info`` is a tuple of the per-environment infos.
    """
    actions = action
    for index, (env, act) in enumerate(zip(self._envs, actions)):
        if not env.action_space.contains(act):
            message = 'Invalid action at index {}: {}'
            raise ValueError(message.format(index, act))
    if not self._blocking:
        promises = [env.step(act, blocking=False)
                    for env, act in zip(self._envs, actions)]
        transitions = [resolve() for resolve in promises]
    else:
        transitions = [env.step(act) for env, act in zip(self._envs, actions)]
    observs, rewards, dones, infos = zip(*transitions)
    return np.stack(observs), np.stack(rewards), np.stack(dones), tuple(infos)
开发者ID:AndrewMeadows,项目名称:bullet3,代码行数:32,代码来源:batch_env.py
示例4: split_data
def split_data(chars, batch_size, num_steps, split_frac=0.9):
    """
    Split character data into training and validation sets, inputs and targets for each set.

    Arguments
    ---------
    chars: character array (1-D, sliceable)
    batch_size: Number of parallel sequences (rows) in each batch
    num_steps: Number of sequence steps to keep in the input and pass to the network
    split_frac: Fraction of batches to keep in the training set

    Returns train_x, train_y, val_x, val_y. Every array has batch_size rows;
    the targets equal the inputs shifted forward by one character.
    """
    slice_size = batch_size * num_steps
    # The targets need one character beyond the last input character.
    # Reserve it up front; otherwise, when len(chars) is an exact multiple of
    # slice_size, y comes out one element short and np.split raises.
    n_batches = max(len(chars) - 1, 0) // slice_size
    # Drop the last few characters to make only full batches
    x = chars[: n_batches * slice_size]
    y = chars[1: n_batches * slice_size + 1]
    # Split the data into batch_size slices, then stack them into a 2D matrix
    x = np.stack(np.split(x, batch_size))
    y = np.stack(np.split(y, batch_size))
    # Now x and y are arrays with dimensions batch_size x n_batches*num_steps
    # Split into training and validation sets, keep the first split_frac batches for training
    split_idx = int(n_batches * split_frac)
    cut = split_idx * num_steps
    train_x, train_y = x[:, :cut], y[:, :cut]
    val_x, val_y = x[:, cut:], y[:, cut:]
    return train_x, train_y, val_x, val_y
开发者ID:Eudie,项目名称:Online-Practice,代码行数:31,代码来源:building_text_generator.py
示例5: load_mask_labels
def load_mask_labels():
    '''Load both target and style masks.

    A mask image (nr x nc) with m labels/colors will be loaded
    as a 4D boolean tensor: (1, m, nr, nc) for 'th' or (1, nr, nc, m) for 'tf'.

    The pixel colours of both masks are clustered together in a single
    ``kmeans`` call, so a given label index means the same colour cluster in
    both masks.  Relies on names defined elsewhere in the script
    (target_mask_path, style_mask_path, img_nrows, img_ncols, nb_labels,
    load_img, img_to_array, kmeans, K).

    Returns:
        (style_mask, target_mask), each with a leading batch axis of size 1.
    '''
    target_mask_img = load_img(target_mask_path,
                               target_size=(img_nrows, img_ncols))
    target_mask_img = img_to_array(target_mask_img)
    style_mask_img = load_img(style_mask_path,
                              target_size=(img_nrows, img_ncols))
    style_mask_img = img_to_array(style_mask_img)
    channels_first = K.image_dim_ordering() == 'th'
    n_pixels = img_nrows * img_ncols
    # One row per pixel, one column per colour channel.
    if channels_first:
        mask_vecs = np.vstack([style_mask_img.reshape((3, -1)).T,
                               target_mask_img.reshape((3, -1)).T])
    else:
        mask_vecs = np.vstack([style_mask_img.reshape((-1, 3)),
                               target_mask_img.reshape((-1, 3))])
    labels = kmeans(mask_vecs, nb_labels)
    # The style pixels were stacked first, so they own the first n_pixels labels.
    style_mask_label = labels[:n_pixels].reshape((img_nrows, img_ncols))
    target_mask_label = labels[n_pixels:].reshape((img_nrows, img_ncols))
    stack_axis = 0 if channels_first else -1
    # `range` rather than the Python-2-only `xrange`: same result, and it
    # also runs on Python 3.
    style_mask = np.stack([style_mask_label == r for r in range(nb_labels)],
                          axis=stack_axis)
    target_mask = np.stack([target_mask_label == r for r in range(nb_labels)],
                           axis=stack_axis)
    return (np.expand_dims(style_mask, axis=0),
            np.expand_dims(target_mask, axis=0))
开发者ID:AnishShah,项目名称:keras,代码行数:32,代码来源:neural_doodle.py
示例6: formatPeaksArbitraryPSF
def formatPeaksArbitraryPSF(peaks, peaks_type):
    """
    Input peaks array formatter for arbitrary PSFs.

    Builds the float64, C-contiguous (N, k) array passed to the C library
    from the per-peak columns stored in ``peaks``:

    * "testing" / "finder": k = 3, columns x, y, z.
    * "text" / "hdf5":      k = 5, columns x, y, z, background, height.

    Any other ``peaks_type`` raises MultiFitterException.  This is primarily
    for internal use by newPeaks().
    """
    if peaks_type in ("testing", "finder"):
        # Peaks from the finder or the unit tests only carry a position.
        column_names = ("x", "y", "z")
    elif peaks_type in ("text", "hdf5"):
        # Pre-specified fitting locations also carry background and height.
        column_names = ("x", "y", "z", "background", "height")
    else:
        raise MultiFitterException("Unknown peaks type '" + peaks_type + "'")
    c_peaks = numpy.stack(tuple(peaks[name] for name in column_names), axis = 1)
    return numpy.ascontiguousarray(c_peaks, dtype = numpy.float64)
开发者ID:ZhuangLab,项目名称:storm-analysis,代码行数:28,代码来源:dao_fit_c.py
示例7: converter
def converter(batch, device, max_caption_length=None):
    """Optional preprocessing of the batch before forward pass.

    Args:
        batch: iterable of ``(img, caption)`` pairs.
        device: target passed to ``to_device``.
        max_caption_length: when given, every caption is copied into an int32
            array of this length, pre-filled with ``_ignore`` and truncated to
            fit (LSTM).  The padded captions are then stacked and moved to
            ``device`` together.  When ``None``, each caption is moved to
            ``device`` on its own as an int32 array (NStepLSTM) and the
            captions are returned as a list.

    Returns:
        ``(imgs, captions)``; ``imgs`` is the stacked images on ``device``.
    """
    pad = max_caption_length is not None

    def prepare(caption):
        # Turn one caption into either a padded array or a device array.
        if not pad:
            return to_device(device, np.asarray(caption, dtype=np.int32))
        padded = np.full(max_caption_length, _ignore, dtype=np.int32)
        # Clip to max length if necessary
        padded[:len(caption)] = caption[:max_caption_length]
        return padded

    imgs, captions = [], []
    for img, caption in batch:
        prepared = prepare(caption)
        imgs.append(img)
        captions.append(prepared)
    if pad:
        captions = to_device(device, np.stack(captions))
    imgs = to_device(device, np.stack(imgs))
    return imgs, captions
开发者ID:Fhrozen,项目名称:chainer,代码行数:26,代码来源:datasets.py
示例8: get_filters
def get_filters(R, filter_size, P=None, n_rings=None):
    """Perform single-frequency DFT on each ring of a polar-resampled patch.

    Args:
        R: mapping from rotation order ``m`` to a tap array ``r``.  ``r`` is
            reshaped to ``(r.shape[0], r.shape[1] * r.shape[2])``, so it is
            assumed to be 3-D -- TODO confirm with callers.
        filter_size: spatial size ``k`` of the output filters.
        P: optional mapping from ``m`` to a phase; when given, the
            cosine/sine pair for order ``m`` is rotated by ``P[m]``.
        n_rings: forwarded to ``get_interpolation_weights``.

    Returns:
        dict mapping ``m`` to ``(ucos, usin)``, each reshaped to
        ``(k, k, r.shape[1], r.shape[2])``.
    """
    k = filter_size
    filters = {}
    N = n_samples(k)
    from scipy.linalg import dft
    # The DFT matrix depends only on N, so build it once, not once per m.
    dft_matrix = dft(N)
    # .items() works on both Python 2 and 3; .iteritems() is Python-2 only.
    for m, r in R.items():
        rsh = r.shape
        # Get the basis matrices
        weights = get_interpolation_weights(k, m, n_rings=n_rings)
        LPF = np.dot(dft_matrix[m, :], weights).T
        cosine = np.real(LPF).astype(np.float32)
        sine = np.imag(LPF).astype(np.float32)
        # Project taps on to rotational basis
        r = np.reshape(r, (rsh[0], rsh[1]*rsh[2]))
        out_shape = (k, k, rsh[1], rsh[2])
        ucos = np.reshape(np.dot(cosine, r), out_shape)
        usin = np.reshape(np.dot(sine, r), out_shape)
        if P is not None:
            # Rotate basis matrices; both outputs use the unrotated inputs.
            ucos, usin = (np.cos(P[m])*ucos + np.sin(P[m])*usin,
                          -np.sin(P[m])*ucos + np.cos(P[m])*usin)
        filters[m] = (ucos, usin)
    return filters
开发者ID:deworrall92,项目名称:groupConvolutions,代码行数:25,代码来源:numpy_hconv.py
示例9: _read
def _read(self, key):
ifnone = lambda a, b: b if a is None else a
y = key[1]
x = key[2]
if isinstance(x, slice):
xstart = ifnone(x.start,0)
xstop = ifnone(x.stop,self.raster_size[0])
xstep = xstop - xstart
else:
raise TypeError("Loc style access elements must be slices, e.g., [:] or [10:100]")
if isinstance(y, slice):
ystart = ifnone(y.start, 0)
ystop = ifnone(y.stop, self.raster_size[1])
ystep = ystop - ystart
else:
raise TypeError("Loc style access elements must be slices, e.g., [:] or [10:100]")
pixels = (xstart, ystart, xstep, ystep)
if isinstance(key[0], (int, np.integer)):
return self.read_array(band=int(key[0]+1), pixels=pixels)
elif isinstance(key[0], slice):
# Given some slice iterate over the bands and get the bands and pixel space requested
arrs = []
for band in list(list(range(1, self.nbands + 1))[key[0]]):
arrs.append(self.read_array(band, pixels = pixels))
return np.stack(arrs)
else:
arrs = []
for b in key[0]:
arrs.append(self.read_array(band=int(b+1), pixels=pixels))
return np.stack(arrs)
开发者ID:USGS-Astrogeology,项目名称:plio,代码行数:34,代码来源:hcube.py
示例10: translist_to_traj
def translist_to_traj(tlist):
    """Stack a list of per-step transitions into a ``Trajectory``.

    Each transition is indexed as ``(obs, obsfeat, adist, action, reward)``.
    Each field is stacked along a new leading time axis and its shape is
    sanity-checked with ``assert``.

    NOTE(review): ``self`` is not a parameter here; it is resolved from an
    enclosing scope (this looks like a nested helper -- confirm).  It must
    provide ``obs_space.storage_size`` and ``action_space.storage_size``.
    """
    def column(i):
        # Field i of every transition, stacked over time.
        return np.stack([trans[i] for trans in tlist])

    obs_T_Do = column(0)
    assert obs_T_Do.shape == (len(tlist), self.obs_space.storage_size)
    obsfeat_T_Df = column(1)
    assert obsfeat_T_Df.shape[0] == len(tlist)
    adist_T_Pa = column(2)
    assert adist_T_Pa.ndim == 2 and adist_T_Pa.shape[0] == len(tlist)
    a_T_Da = column(3)
    assert a_T_Da.shape == (len(tlist), self.action_space.storage_size)
    r_T = column(4)
    assert r_T.shape == (len(tlist),)
    return Trajectory(obs_T_Do, obsfeat_T_Df, adist_T_Pa, a_T_Da, r_T)
开发者ID:1769948908,项目名称:imitation,代码行数:7,代码来源:__init__.py
示例11: get_non_missing
def get_non_missing(ids, x, y, real_codes):
    """
    Drop every record whose real code is the literal string '""' (missing).

    The four inputs are walked in lockstep (truncated to the shortest, as
    ``zip`` does).

    :param ids: record identifiers
    :param x: per-record text features; all must have the same shape
    :param y: per-record code features; all must have the same shape
    :param real_codes: per-record codes; '""' marks a missing value and
        every other entry must be convertible with ``float()``
    :return: ``[id_clean, text_clean, code_clean, real_codes_clean]``: the
        kept ids as a list, the kept ``x`` and ``y`` stacked along a new first
        axis, and the kept codes as a float array.
    :raises ValueError: (from ``np.stack``) if no record survives.
    """
    # Filter with a comprehension instead of np.array(zip(...), dtype=object).
    # On Python 3, zip is a lazy iterator, which numpy wraps as a 0-d object
    # array, so the old column indexing dataset[:, 3] raised IndexError.
    kept = [row for row in zip(ids, x, y, real_codes) if not row[3] == '""']
    id_clean = [row[0] for row in kept]
    # Turns real_codes into floats
    real_codes_clean = np.array([float(row[3]) for row in kept])
    # Makes everything one array (2-D when each element is 1-D) instead of array of arrays
    text_clean = np.stack([row[1] for row in kept], axis=0)
    code_clean = np.stack([row[2] for row in kept], axis=0)
    return [id_clean, text_clean, code_clean, real_codes_clean]
开发者ID:AdamHede,项目名称:text_cnn,代码行数:25,代码来源:data_importer_mine2.py
示例12: read
def read(self, input_path):
    '''
    Reads in the data from input files.

    Globs ``input_path + '*'``, keeps a random subset of at most 150 files
    (processed in sorted order) and loads each one as a high-resolution
    target.  A random "depth" channel is appended as a placeholder until real
    depth is read.  The low-resolution input is derived with
    ``compute_lr_input``.  The stacked results are stored in
    ``self.sr_outputs`` and ``self.lr_inputs``.

    Raises:
        IOError: if no file matches, or an image cannot be decoded or is empty.
    '''
    self.lr_inputs = None
    self.sr_outputs = None
    print(input_path)
    filenames = glob.glob(input_path + '*')
    if not filenames:
        raise IOError('No input files match: ' + input_path + '*')
    random.shuffle(filenames)
    filenames = filenames[0:150]
    print('Length: ' + str(len(filenames)))
    filenames.sort()
    outputs = []
    inputs = []
    for filename in filenames:
        output_img = cv2.imread(filename)
        # cv2.imread signals failure by returning None instead of raising,
        # so check that explicitly along with the empty-image case.
        if output_img is None or output_img.shape[0] == 0 or output_img.shape[1] == 0:
            raise IOError('Could not read image: ' + filename)
        # TODO: read in actual depth
        output_depth = np.random.random((output_img.shape[0], output_img.shape[1], 1))
        output_img = np.concatenate((output_img, output_depth), 2)
        outputs.append(output_img)
        input_img = compute_lr_input(
            output_img, downsampling_factor_x=2,
            downsampling_factor_y=2, blur_sigma=1.6, noise_sigma=0.03)
        inputs.append(input_img)
    self.sr_outputs = np.stack(outputs, axis=0)
    self.lr_inputs = np.stack(inputs, axis=0)
开发者ID:seb-kro,项目名称:RGB-D_SR_Net,代码行数:33,代码来源:dataset.py
示例13: step
def step(self, action):
    """Apply one drawing action to the canvas.

    ``action`` holds two halves of length ``image_width`` (values assumed to
    lie in [-1, 1] -- TODO confirm).  Each half is mapped with (v + 1) / 2,
    their outer product is scaled by 255/4, and the result is subtracted from
    the 3 channels of ``self.canvas``.  The subtraction is clamped so the
    unsigned canvas cannot wrap below zero.  Afterwards canvas and target are
    both rotated by 90 degrees.

    Returns:
        ``(ob, reward, done, info)``.  ``reward`` is
        ``(lastdiff - diff) / rewardscale``, ``done`` is
        ``stepnum >= max_step`` and ``info`` is ``None``.
    """
    x = (action[:image_width] + 1) / 2.
    y = (action[image_width:] + 1) / 2.
    grey = x * y.reshape(image_width, 1)
    grey = grey.reshape((image_width, image_width, 1))
    grey = (grey * (255, 255, 255) / 4).astype('uint8')
    # Never subtract more than the canvas holds, so uint8 values cannot wrap.
    grey = np.minimum(grey, self.canvas)
    self.canvas -= grey
    diff = self.diff()
    reward = (self.lastdiff - diff) / self.rewardscale  # reward is positive if diff decreased
    self.lastdiff = diff
    self.stepnum += 1
    ob = self.observation()
    # Rotate by 90 degrees and keep a contiguous copy (what the former
    # np.stack(np.rot90(...)) produced, just stated directly).
    self.canvas = np.rot90(self.canvas).copy()
    self.target = np.rot90(self.target).copy()
    self.time += 1. / max_step
    return ob, reward, (self.stepnum >= max_step), None  # o,r,d,i
开发者ID:megvii-rl,项目名称:pytorch-gym,代码行数:26,代码来源:env.py
示例14: main
def main(args):
    """Apply a saved model to a set of images and save before/after grids.

    Loads ``args.model_filename`` and rescales every image matched by
    ``args.image_glob`` by ``args.output_scale``.  Images are grouped by
    size.  For each group the model's layers are re-applied on an ``Input``
    of that shape, the result is run ``args.apply_n`` times, and one grid
    per image (original and result) is saved to ``args.output_path`` as
    ``'{size}_{i}.png'``.  Pixel values are scaled to [-1, 1] via
    ``/ 127.5 - 1``.
    """
    # load the model
    model = load_model(args.model_filename, custom_objects={
        'SubPixelUpscaling': SubPixelUpscaling
    })
    # print() call form behaves the same on Python 2 and is required on Python 3.
    print(model.layers)
    # load the images and bucket them by shape
    images_by_size = defaultdict(list)
    for filename in glob.glob(args.image_glob):
        img = Image.open(filename)
        # scale up; build a real tuple because map() is lazy on Python 3
        new_size = tuple(int(dim * args.output_scale) for dim in img.size)
        img = img.resize(new_size)
        images_by_size[img.size].append(img)
    # apply the model to the images
    for size, imgs in images_by_size.items():
        # np.stack needs a sequence, not a lazy map object
        images = np.stack([img_to_array(img) for img in imgs])
        images = (images / 127.5) - 1.
        # NOTE: :( rebuild the network on an Input matching this bucket's shape
        x = input_layer = Input(shape=images.shape[1:])
        for layer in model.layers[1:]:
            x = layer(x)
        this_model = Model([input_layer], [x])
        this_model.compile(optimizer='sgd', loss='mse')
        # END :(
        new_images = images
        for _ in range(args.apply_n):
            new_images = this_model.predict(new_images, verbose=False)
        # save before/after images
        for i in range(new_images.shape[0]):
            samples = np.stack([images[i], new_images[i]])
            filename = '{}_{}.png'.format(size, i)
            filename = os.path.join(args.output_path, filename)
            print('saving sample', samples.shape, filename)
            save_sample_grid(samples, filename)
开发者ID:awentzonline,项目名称:bob-loss,代码行数:35,代码来源:apply_fc.py
示例15: calc_score
def calc_score(self):
    """Return the fraction of iterations in which player 0 wins outright.

    For each card type, the detection flags (``self.highcard`` ...
    ``self.straightflush``) are multiplied by the type's multiplier and by
    the matching hand value.  The best of these scores decides each
    (iteration, player) entry.  Player 0 scores a win only when it is the
    sole holder of the iteration's maximum; ties count as losses.

    Side effects: sets ``self.cardtype_multiplier``, ``self.detected_types``,
    ``self.hand_vals``, ``self.active_multiplier`` and
    ``self.ordered_multiplier``.
    """
    # NOTE(review): 'straighflush_multiplier' (missing 't') is kept as-is;
    # presumably other code sets exactly this attribute name.
    self.cardtype_multiplier = np.array(
        [self.highcard_multiplier, self.pair_multiplier, self.twopair_multiplier, self.threeofakind_multiplier,
         self.straight_multiplier, self.flush_multiplier, self.fullhouse_multiplier, self.fourofakind_multiplier,
         self.straighflush_multiplier])
    self.detected_types = np.stack((self.highcard, self.pair, self.twopair, self.threeofakind,
                                    self.straight, self.flush, self.fullhouse, self.fourofakind,
                                    self.straightflush), axis=0)
    self.hand_vals = np.stack((self.highCardsVal, self.pairScore, self.twoPairScore, self.threeScore,
                               self.straightScore, self.flushScore, self.fullhouseScore,
                               self.fourofakindScore, self.straightflush_score), axis=0)
    # "* 1" turns the boolean detection flags into integers.
    detected_types = self.detected_types * 1
    # Score per (card type, iteration, player); undetected types score 0.
    self.active_multiplier = self.cardtype_multiplier[:, None, None] * detected_types * self.hand_vals
    # Sort over card types, best first, so row 0 holds each player's best score.
    self.ordered_multiplier = np.sort(self.active_multiplier, axis=0)[::-1, :, :]
    best = self.ordered_multiplier[0, :, :]
    # True where a player's best score equals the top score of that iteration.
    winners = best == np.amax(best, axis=1)[:, None]
    # Player 0 wins an iteration only if it is the one and only top scorer.
    my_winner_mask = np.zeros(self.player_amount, dtype=int)
    my_winner_mask[0] = 1
    my_wins = np.sum((winners == my_winner_mask).all(1), axis=0)
    return my_wins / self.iterations
开发者ID:dickreuter,项目名称:Poker,代码行数:35,代码来源:montecarlo_numpy2.py
示例16: add_polygons
def add_polygons(self, polygons, y_offset, x_offset, dimensions):
    '''Creates a label image representation of segmented objects based
    on global map coordinates of object contours.

    Parameters
    ----------
    polygons: List[List[Tuple[Union[int, shapely.geometry.polygon.Polygon]]]]
        label and polygon geometry for segmented objects at each z-plane
        and time point
    y_offset: int
        global vertical offset that needs to be subtracted from
        y-coordinates
    x_offset: int
        global horizontal offset that needs to be subtracted from
        x-coordinates
    dimensions: Tuple[int]
        *y*, *x* dimensions of image pixel planes

    Returns
    -------
    numpy.ndarray[numpy.int32]
        label image.  Within each outer element of ``polygons`` the rendered
        planes are stacked along a new last axis, and those stacks are then
        stacked along another new last axis (``(y, x, z, t)`` if each plane
        is ``(y, x)``).  The result is also stored in ``self.value``.
    '''
    def render(plane_polygons):
        # Rasterise one plane's polygons into a label array.
        return SegmentationImage.create_from_polygons(
            plane_polygons, y_offset, x_offset, dimensions
        ).array

    per_timepoint = [
        np.stack([render(p) for p in poly], axis=-1)
        for poly in polygons
    ]
    self.value = np.stack(per_timepoint, axis=-1)
    return self.value
开发者ID:dvischi,项目名称:TissueMAPS,代码行数:35,代码来源:handles.py
示例17: corrcoef_raftscope
def corrcoef_raftscope(raftsfits, ROIrows, ROIcols, norm=True):
    """
    Correlation over one or more CCDs, calculating correlation along lines at each time index in ROIcols,
    then averaging.

    From every file, the region ``[ROIrows, ROIcols]`` of HDUs 1..16 is
    extracted.  All regions are stacked into an array of shape
    (n_files * 16, n_rows, n_cols).  For each column, the 16 * n_files
    regions are treated as variables and the rows as observations; the
    per-column matrices are then averaged.

    :param raftsfits: file list
    :param ROIrows: must be in the format: slice(start, stop)
    :param ROIcols: must be in the format: slice(start, stop)
    :param norm: if True, computes correlation coefficients; if not, returns covariances
    :return: (16 * n_files, 16 * n_files) array averaged over the ROI columns
    """
    stackh = []
    for fl in raftsfits:
        h = pyfits.open(fl)
        try:
            for i in range(1, 17):
                stackh.append(h[i].data[ROIrows, ROIcols])
        finally:
            # Close the file even if an HDU is missing or the ROI is invalid.
            h.close()
    stackh = np.stack(stackh)
    # print() call form behaves the same on Python 2 and is required on Python 3.
    print(stackh.shape)
    # Loop over the columns actually extracted instead of
    # ROIcols.stop - ROIcols.start, which breaks for a None bound, a step,
    # or a ROI that extends past the image edge.
    a = []
    for numcol in range(stackh.shape[2]):
        if norm:
            a.append(np.corrcoef(stackh[:, :, numcol]))
        else:
            a.append(np.cov(stackh[:, :, numcol]))
    a = np.stack(a)
    print(a.shape)
    return a.mean(axis=0)
开发者ID:lsst-camera-dh,项目名称:harnessed-jobs,代码行数:32,代码来源:multiscope.py
示例18: extract_features
def extract_features(ids, path, output_path, extractor, batch_size=64):
    """Run ``extractor`` over the images in ``path`` whose id is in ``ids``.

    The image id is parsed from the file name as the integer between the
    last '_' and the following '.' (e.g. ``x_000123.jpg`` -> 123).  Images
    are loaded with ``load_image`` and fed to ``sess`` in batches of up to
    ``batch_size``, using the module-level ``images`` tensor as the
    feed_dict key.  Each feature vector is written to ``output_path`` as the
    line ``"<id>;<v1> <v2> ..."``.  Progress is printed after every batch.
    """
    images_names = {}
    for fname in listdir(path):
        image_id = int(fname.split('_')[-1].split('.')[0])
        if image_id in ids:
            images_names[image_id] = fname

    def run_batch(batch, names, idx, output_file):
        # Extract features for one batch, append them and report progress.
        feed_dict = {images: np.stack(batch)}
        with tf.device('/gpu:0'):
            features = sess.run(extractor, feed_dict=feed_dict)
        for name, feat in zip(names, features):
            output_file.write("%s;%s\n" % (name, " ".join(str(v) for v in feat)))
        print("%d/%d" % (idx, len(images_names)))
        output_file.flush()

    batch, names = [], []
    with open(output_path, 'w') as output_file:
        for idx, image_id in enumerate(images_names):
            batch.append(load_image(join(path, images_names[image_id])))
            names.append(image_id)
            if len(batch) == batch_size:
                run_batch(batch, names, idx, output_file)
                batch, names = [], []
        # Flush the final, partially filled batch.
        if batch:
            run_batch(batch, names, idx, output_file)
开发者ID:Hediby,项目名称:vanilla_vqa,代码行数:31,代码来源:extract_features_cocoqa.py
示例19: test_arrayize_vectorized_indexer
def test_arrayize_vectorized_indexer(self):
    """Check ``indexing._arrayize_vectorized_indexer`` on generated and hand-written keys.

    Every 3-way combination of ``self.indexers`` must select the same values
    from ``self.data`` before and after conversion.  A few conversions are
    also compared against explicitly constructed integer arrays.
    """
    arrayize = indexing._arrayize_vectorized_indexer
    VIndexer = indexing.VectorizedIndexer

    for key in itertools.product(self.indexers, repeat=3):
        vindex = VIndexer(key)
        converted = arrayize(vindex, self.data.shape)
        np.testing.assert_array_equal(self.data[vindex], self.data[converted],)

    # A lone full slice becomes a single arange over the dimension.
    only_slice = arrayize(VIndexer((slice(None),)), shape=(5,))
    np.testing.assert_array_equal(only_slice.tuple, [np.arange(5)])

    # Pure integer-array keys come back as the same arrays.
    expected = np.stack([np.arange(5)] * 3)
    all_arrays = arrayize(VIndexer((np.arange(5),) * 3), shape=(8, 10, 12))
    np.testing.assert_array_equal(np.stack(all_arrays.tuple), expected)

    # Mixed keys: the array and the expanded slice end up on separate
    # broadcast axes.
    first, second = arrayize(VIndexer((np.arange(5), slice(None))), shape=(8, 10)).tuple
    np.testing.assert_array_equal(first, np.arange(5)[:, np.newaxis])
    np.testing.assert_array_equal(second, np.arange(10)[np.newaxis, :])

    first, second = arrayize(VIndexer((slice(None), np.arange(5))), shape=(8, 10)).tuple
    np.testing.assert_array_equal(first, np.arange(8)[np.newaxis, :])
    np.testing.assert_array_equal(second, np.arange(5)[:, np.newaxis])
开发者ID:crusaderky,项目名称:xarray,代码行数:30,代码来源:test_indexing.py
示例20: _feed_dict
def _feed_dict(self, train_batch, is_training=True):
pred_polys = train_batch['raw_polys'] * np.expand_dims(train_batch['masks'], axis=2) # (seq,batch,2)
pred_polys = np.transpose(pred_polys, [1, 0, 2]) # (batch,seq,2)
pred_mask = np.transpose(train_batch['masks'], [1, 0]) # (batch_size,seq_len)
cnn_feats = train_batch['cnn_feats'] # (batch_size, 28, 28, 128)
cells_1 = np.stack([np.split(train_batch['hiddens_list'][-1][0], 2, axis=3)[0]], axis=1)
cells_2 = np.stack([np.split(train_batch['hiddens_list'][-1][1], 2, axis=3)[0]], axis=1)
pred_mask_imgs = self.draw_mask(28, 28, pred_polys, pred_mask)
if is_training:
raise NotImplementedError()
r = {
self._ph.cells_1: cells_1,
self._ph.cells_2: cells_2,
self._ph.pred_mask_imgs: pred_mask_imgs,
self._ph.cnn_feats: cnn_feats,
self._ph.predicted_mask: pred_mask,
self._ph.pred_polys: pred_polys,
self._ph.ious: self._zero_batch
}
return r
开发者ID:zhouleidcc,项目名称:polyrnn-pp,代码行数:28,代码来源:EvalNet.py
注:本文中的numpy.stack函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。 |
请发表评论