本文整理汇总了Python中numpy.where函数的典型用法代码示例。如果您正苦于以下问题：Python where函数的具体用法？Python where怎么用？Python where使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了where函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test_float_modulus_exact
def test_float_modulus_exact(self):
# test that float results are exact for small integers. This also
# holds for the same integers scaled by powers of two.
nlst = list(range(-127, 0))
plst = list(range(1, 128))
dividend = nlst + [0] + plst
divisor = nlst + plst
arg = list(itertools.product(dividend, divisor))
tgt = list(divmod(*t) for t in arg)
a, b = np.array(arg, dtype=int).T
# convert exact integer results from Python to float so that
# signed zero can be used, it is checked.
tgtdiv, tgtrem = np.array(tgt, dtype=float).T
tgtdiv = np.where((tgtdiv == 0.0) & ((b < 0) ^ (a < 0)), -0.0, tgtdiv)
tgtrem = np.where((tgtrem == 0.0) & (b < 0), -0.0, tgtrem)
for dt in np.typecodes['Float']:
msg = 'dtype: %s' % (dt,)
fa = a.astype(dt)
fb = b.astype(dt)
# use list comprehension so a_ and b_ are scalars
div = [self.floordiv(a_, b_) for a_, b_ in zip(fa, fb)]
rem = [self.mod(a_, b_) for a_, b_ in zip(fa, fb)]
assert_equal(div, tgtdiv, err_msg=msg)
assert_equal(rem, tgtrem, err_msg=msg)
开发者ID:birm,项目名称:numpy,代码行数:26,代码来源:test_scalarmath.py
示例2: soft_threshold
def soft_threshold(lamda,b):
    """Apply element-wise soft thresholding to ``b`` with threshold ``lamda / 2``.

    With ``th = lamda / 2``:
    - entries greater than ``th`` become ``b - th``;
    - entries less than ``-th`` become ``b + th``;
    - every other entry (including NaN) becomes 0.

    Parameters
    ----------
    lamda : number
        Regularisation weight. The threshold actually used is ``lamda / 2``.
    b : array_like
        Input values. Any shape is accepted.

    Returns
    -------
    ndarray or the input
        A new float64 array with the same shape as ``b``. When ``lamda == 0``
        the input ``b`` itself is returned (not a copy).
    """
    th = float(lamda) / 2.0
    if lamda == 0:
        return b
    b = np.asarray(b)
    # Entries in [-th, th] are already 0 here, so they need no explicit pass.
    x = np.zeros(b.shape)
    above = b > th
    x[above] = b[above] - th
    # Applied after ``above`` so that, for a negative lamda (overlapping
    # masks), the ``b + th`` branch wins as it did originally.
    below = b < -th
    x[below] = b[below] + th
    return x
开发者ID:ykwon0407,项目名称:rcae,代码行数:29,代码来源:section_5.3_image_denoising_results_script_CAE.py
示例3: plotResult
def plotResult(self, nn):
    """Scatter the samples of classes 1, 2 and 3 and overlay ``nn``'s output.

    Class 1 is drawn with ``'or'``, class 2 with ``'ob'`` and class 3 with
    ``'og'``. ``nn.getNetworkOutput([a, b])`` is then evaluated on a grid with
    step 0.05 covering the bounding box of ``self.X[:, 0]`` / ``self.X[:, 1]``.
    The result is drawn as a dashed black contour at level 1 plus a filled
    contour with 100 levels in the RdBu colormap.
    """
    for label, fmt in ((1, 'or'), (2, 'ob'), (3, 'og')):
        cmask = np.where(self.y == label)
        plot(self.X[cmask, 0], self.X[cmask, 1], fmt, markersize=4)

    col_x = self.X[:, 0]
    col_y = self.X[:, 1]
    grid_range = [min(col_x), max(col_x), min(col_y), max(col_y)]
    delta = 0.05
    levels = 100

    a = arange(grid_range[0], grid_range[1], delta)
    b = arange(grid_range[2], grid_range[3], delta)
    A, B = meshgrid(a, b)
    values = np.zeros(A.shape)
    # meshgrid puts the b axis first, hence values[j, i].
    for i, a_val in enumerate(a):
        for j, b_val in enumerate(b):
            values[j, i] = nn.getNetworkOutput([a_val, b_val])

    contour(A, B, values, levels=[1], colors=['k'], linestyles='dashed')
    contourf(A, B, values, levels=linspace(values.min(), values.max(), levels), cmap=cm.RdBu)
开发者ID:dzitkowskik,项目名称:NeuralNetwork-pybrain,代码行数:25,代码来源:csv_data.py
示例4: infos
def infos():
    '''Print summary statistics about the striatal and pallidal populations.

    Reads the module-level globals STR, GP, W_STR_STR, W_GP_GP, W_STR_GP and
    W_GP_STR. Uses only single-argument ``print(...)`` calls, so the output is
    identical on Python 2 and Python 3.
    '''
    def report(title, W):
        # Per-row count of strictly positive plus strictly negative weights.
        C = (W > 0).sum(axis=1) + (W < 0).sum(axis=1)
        # All nonzero weights, flattened.
        L = W[np.where(W != 0)]
        print(title)
        print("Mean number: %g (+/- %g)" % (C.mean(), C.std()))
        print("Mean length: %g (+/- %g)" % (L.mean(), L.std()))
        print("")
        return C

    print("Striatal populations: %d" % len(STR))
    print("Pallidal populations: %d" % len(GP))
    print("")
    report("Collateral striatal connections", W_STR_STR)
    report("Collateral pallidal connections", W_GP_GP)
    C = report("Striato-pallidal connections", W_STR_GP)
    # NOTE(review): this repeats the striato-pallidal "Mean number" figures
    # printed just above under a different label; kept for output compatibility.
    print("Mean # collateral striato-pallidal connections: %g (+/- %g)" % (C.mean(), C.std()))
    report("Pallido-striatal connections", W_GP_STR)
开发者ID:Aurelien1609,项目名称:Computational-model,代码行数:35,代码来源:frequency_model.py
示例5: compute_cost
def compute_cost( X, y, theta, lam ):
    '''Compute the regularised cost for logistic regression.

    Parameters
    ----------
    X : design matrix; ``X.dot(theta)`` gives the linear predictions.
    y : targets, one per training example (``y.shape[0]`` examples).
    theta : parameter vector; ``theta[0]`` is excluded from the penalty.
    lam : weight of the L2 penalty ``lam / (2m) * sum(theta[1:]**2)``.

    Returns
    -------
    The cross-entropy cost plus the L2 penalty.
    '''
    # Number of training examples
    m = y.shape[0]
    # Clip the sigmoid argument to [-500, 20] so that neither log(h) nor
    # log(1 - h) below is evaluated at exactly 0.
    predictions = np.clip(X.dot(theta), -500, 20)
    hypothesis = sigmoid(predictions)
    # Extra guard keeping log(1 - hypothesis) finite.
    # NOTE(review): unreachable after the clip above if ``sigmoid`` is the
    # standard logistic function -- confirm before removing.
    hypothesis[hypothesis == 1.0] = 0.99999
    # Cross-entropy part of the cost (no regularisation)
    J1 = ( -1.0 / m ) * sum( ( y * np.log( hypothesis ) ) +
                             ( ( 1.0 - y ) * np.log( 1.0 - hypothesis ) ) )
    # L2 regularisation term; the intercept theta[0] is not penalised
    J2 = lam / ( 2.0 * m ) * sum( theta[ 1:, ] * theta[ 1:, ] )
    return J1 + J2
开发者ID:racheesingh,项目名称:py-birth-weight-prediction,代码行数:28,代码来源:LogisticRegression.py
示例6: pixel_to_prime
def pixel_to_prime(self, x, y, color=0):
    """Map pixel coordinates ``(x, y)`` to the "prime" frame.

    Reference: http://www.sdss.org/dr7/products/general/astrometry.html
    In that document, (color)0 is called riCut; g0..g3 are dRow0..dRow3;
    h0..h3 are dCol0..dCol3; px/py are csRow/csCol; and qx/qy are ccRow/ccCol.

    Both primes are cubic polynomials in ``x`` plus a colour term. The term
    is ``p * color`` when ``color < riCut`` and the constant ``q`` otherwise.
    ``color`` may be a scalar or an array (the test is vectorised with
    ``np.where``). Returns ``(xprime, yprime)``.
    """
    color0 = self._get_ricut()
    g0, g1, g2, g3 = self._get_drow()
    h0, h1, h2, h3 = self._get_dcol()
    # The documented row/column order is swapped (#$(%*&^(%$%*& bad
    # documentation), so unpack with x and y exchanged.
    py, px, qy, qx = self._get_cscc()

    yprime = y + g0 + g1 * x + g2 * x**2 + g3 * x**3
    xprime = x + h0 + h1 * x + h2 * x**2 + h3 * x**3

    below_cut = color < color0
    xprime += np.where(below_cut, px * color, qx * np.ones_like(x))
    yprime += np.where(below_cut, py * color, qy * np.ones_like(y))
    return (xprime, yprime)
开发者ID:joshuawallace,项目名称:astrometry.net,代码行数:35,代码来源:common.py
示例7: processTrafficData
def processTrafficData(self):
    """Give every row's ``node`` the list of adjacent intersection keys.

    A key is the string ``"<inter1>, <inter2>"`` built from a row. For each
    row, the candidate keys come from two places:
    - the rows directly before and after it that share its first field;
    - the rows directly before and after the lowest-position "reverse" row
      (``inter1 == row[1]`` and ``inter2 == row[0]``) that share that
      reverse row's ``inter1``.
    The result is passed to ``row['node'].setAdjacents``.
    """
    data = self.traffic_data
    # NOTE(review): 18467 is hard-coded, presumably len(self.traffic_data) - 1;
    # only labels strictly inside (1, 18467) get neighbours -- confirm intent.
    lower, upper = 1, 18467

    def key_at(label):
        # "<inter1>, <inter2>" of the row with this label. ``.loc`` replaces
        # ``.ix``, which was removed in pandas 1.0; both are label-based for
        # an integer index.
        entry = data.loc[label]
        return entry['inter1'] + ', ' + entry['inter2']

    def neighbours(label, inter1):
        # Keys of the rows just before and just after ``label`` whose
        # inter1 equals ``inter1``.
        found = []
        for other in (label - 1, label + 1):
            if data.loc[other]['inter1'] == inter1:
                found.append(key_at(other))
        return found

    for index, row in data.iterrows():
        # Positional access, as the original ``row[0]`` / ``row[1]`` did.
        first, second = row.iloc[0], row.iloc[1]
        adjacent_list = []
        if lower < index < upper:
            adjacent_list.extend(neighbours(index, first))
        # NOTE(review): np.where yields positions which are then used as
        # labels; this assumes a default RangeIndex.
        reverse_rows = np.intersect1d(np.where(data['inter1'] == second)[0],
                                      np.where(data['inter2'] == first)[0])
        if len(reverse_rows) >= 1:
            ind = reverse_rows[0]
            if lower < ind < upper:
                adjacent_list.extend(neighbours(ind, second))
        row['node'].setAdjacents(adjacent_list)
开发者ID:jungsw,项目名称:cs221,代码行数:31,代码来源:processor.py
示例8: word2ind
def word2ind(word, vocab, utok_ind=None):
    """Return the index of ``word`` in the array ``vocab``.

    If ``word`` does not occur, fall back to the index of the unknown token
    ``UTOK``. ``utok_ind`` may carry a precomputed ``np.where(vocab == UTOK)``
    result so that it is not recomputed on every call; when it is falsy it
    is computed here. Raises IndexError if neither lookup matches.
    """
    hits = np.where(vocab == unicode(word))
    if len(hits[0]) == 0:
        hits = utok_ind or np.where(vocab == UTOK)
    return hits[0][0]
开发者ID:nturusin,项目名称:allenchallenge,代码行数:7,代码来源:AAI_lasagne_MemNN_7.py
示例9: test_background_model
def test_background_model(tmpdir):
    """End-to-end check of OffDataBackgroundMaker on the HESS Crab test runs.

    Covers four things:
    - the selected run list;
    - the observation grouping and the group definitions;
    - the total counts of the saved 2D background model;
    - the background file names that the total index table assigns to runs
      23523 and 23526.
    """
    outdir = str(tmpdir)
    data_store = DataStore.from_dir('$GAMMAPY_EXTRA/datasets/hess-crab4-hd-hap-prod2/')
    bgmaker = OffDataBackgroundMaker(data_store, outdir=outdir)

    bgmaker.select_observations(selection='all')
    # NOTE(review): read from the current working directory, not tmpdir.
    run_list = Table.read('run.lis', format='ascii.csv')
    assert run_list['OBS_ID'][1] == 23526

    bgmaker.group_observations()
    obs_table = ObservationTable.read(str(tmpdir / 'obs.fits'))
    assert list(obs_table['GROUP_ID']) == [0, 0, 0, 1]
    group_def = ObservationTable.read(str(tmpdir / 'group-def.fits'))
    assert list(group_def['ZEN_PNT_MAX']) == [49, 90]

    # TODO: Fix 3D code
    # bgmaker.make_model("3D")
    # bgmaker.save_models("3D")
    # model = CubeBackgroundModel.read(str(tmpdir / 'background_3D_group_001_table.fits.gz'))
    # assert model.counts_cube.data.sum() == 1527

    bgmaker.make_model("2D")
    bgmaker.save_models("2D")
    model = EnergyOffsetBackgroundModel.read(str(tmpdir / 'background_2D_group_001_table.fits.gz'))
    assert model.counts.data.value.sum() == 1398

    index_table_new = bgmaker.make_total_index_table(data_store, "2D", None, None)
    table_bkg = index_table_new[np.where(index_table_new["HDU_NAME"] == "bkg_2d")]
    expected_files = [
        (23523, '/background_2D_group_001_table.fits.gz'),
        (23526, '/background_2D_group_000_table.fits.gz'),
    ]
    for obs_id, suffix in expected_files:
        file_names = table_bkg[np.where(table_bkg["OBS_ID"] == obs_id)]["FILE_NAME"]
        assert outdir + "/" + file_names[0] == outdir + suffix
开发者ID:astrofrog,项目名称:gammapy,代码行数:31,代码来源:test_off_data_background_maker.py
示例10: _lininterp
def _lininterp(self,x,X,Y):
if hasattr(x,'__len__'):
xtype = 'array'
xx=np.asarray(x).astype(np.float)
else:
xtype = 'scalar'
xx=np.asarray([x]).astype(np.float)
idx = X.searchsorted(xx)
yy = xx*0
yy[idx>len(X)-1] = Y[-1] # over
yy[idx<=0] = Y[0] # under
wok = np.where((idx>0) & (idx<len(X))) # the good ones
iok=idx[wok]
yywok = Y[iok-1] + ( (Y[iok]-Y[iok-1])/(X[iok]-X[iok-1])
* (xx[wok]-X[iok-1]) )
w = np.where( ((X[iok]-X[iok-1]) == 0) ) # where are the nan ?
yywok[w] = Y[iok[w]-1] # replace by previous value
wl = np.where(xx[wok] == X[0])
yywok[wl] = Y[0]
wh = np.where(xx[wok] == X[-1])
yywok[wh] = Y[-1]
yy[wok] = yywok
if xtype == 'scalar':
yy = yy[0]
return yy
开发者ID:montefra,项目名称:healpy,代码行数:25,代码来源:projaxes.py
示例11: corrtag_image
def corrtag_image(in_data,xtype='XCORR',ytype='YCORR',pha=(2,30),bins=(1024,16384),times=None,ranges=((0,1023),(0,16384)),binning=(1,1),NUV=False):
    """Bin a corrtag event list into a 2-D counts image.

    Parameters
    ----------
    in_data : object
        First passed to ``pyfits.getdata(in_data, 1)``. If that raises, it is
        used directly as the event table.
    xtype, ytype : str
        Event columns used for the image x and y axes.
    pha : (min, max)
        Inclusive PHA range of the events kept.
    bins, ranges :
        Passed to ``numpy.histogram2d`` (y column first, then x column).
    times : (start, stop) or None
        Optional inclusive cut on the ``TIME`` column.
    binning : (int, int)
        Only used to size the all-zero image returned when no event passes
        the PHA cut.
    NUV : bool
        If True, ``bins``, ``pha`` and ``ranges`` are overridden with
        1024x1024, (-1, 1) and ((0, 1023), (0, 1023)).

    Returns
    -------
    ndarray
        The counts image.
    """
    # Plain function-scope imports. The former "try: name / except NameError:
    # import" guard always ended up importing, because an import statement
    # inside a function makes those names local to it.
    from numpy import histogram2d, where, zeros
    from pyfits import getdata

    try:
        events = getdata(in_data, 1)
    except Exception:
        # Not readable by getdata: assume an already-loaded event table.
        events = in_data

    if NUV:
        bins = (1024, 1024)
        pha = (-1, 1)
        ranges = ((0, 1023), (0, 1023))

    # ``is not None``: ``times != None`` is element-wise (and then
    # ambiguous in ``if``) when times is a numpy array.
    if times is not None:
        index = where((events['TIME'] >= times[0]) & (events['TIME'] <= times[1]))
        events = events[index]

    index = where((events['PHA'] >= pha[0]) & (events['PHA'] <= pha[1]))
    if len(index[0]):
        image, y_r, x_r = histogram2d(events[ytype][index], events[xtype][index], bins=bins, range=ranges)
    else:
        # NOTE(review): this shape divides by ``binning`` while the histogram
        # branch does not; both agree only for binning == (1, 1).
        image = zeros((bins[0] // binning[0], bins[1] // binning[1]))
    return image
开发者ID:justincely,项目名称:cos_monitoring,代码行数:30,代码来源:utils.py
示例12: __init__
def __init__(self,turn,elem,single,name,s,x,xp,y,yp,pc,de,tau,**args):
apc=float(pc[0])*1e9
ade=float(de[0])
self.m0=self.pmass
en=np.sqrt(apc**2+self.pmass**2)
self.e0=en-ade
self.p0c=np.sqrt(self.e0**2-self.m0**2)
# structure
self.elem=np.array(elem,dtype=int)
self.turn=np.array(turn,dtype=int)
d0=np.where(np.diff(self.elem)!=0)[0][0]+1
d1=(np.where(np.diff(self.turn)!=0)[0][0]+1)/d0
d2=len(self.turn)/d1/d0
self.single=np.array(single,dtype=int)
self.name=np.array(name,dtype=str)
self.s =np.array(s ,dtype=float)
self.x =np.array(x ,dtype=float)
self.y =np.array(y ,dtype=float)
self.tau=-np.array(tau,dtype=float)*self.clight
opd=np.array(pc,dtype=float)*(1e9/self.p0c)
self.delta=opd-1
self.pt=np.array(de,dtype=float)/self.p0c
self.px=np.array(xp,dtype=float)*opd
self.py=np.array(yp,dtype=float)*opd
for nn,vv in self.__dict__.items():
if hasattr(vv,'__len__') and len(vv)==d0*d1*d2:
setattr(self,nn,vv.reshape(d2,d1,d0))
开发者ID:vrosnet,项目名称:SixTrackLib,代码行数:27,代码来源:sixdump.py
示例13: mutual_information
def mutual_information(self, x_index, y_index, log_base, debug=False):
    """
    Calculate and return the mutual information between two random variables.

    The variables are the rows ``self.data[x_index]`` and
    ``self.data[y_index]``. Probabilities are empirical frequencies over
    ``self.n_cols`` observations. The result is
    ``sum p(x,y) * log(p(x,y) / (p(x) p(y)))`` in base ``log_base``, taken
    over pairs with ``p(x,y) > 0``.
    """
    import numpy as np

    # Check if index are into the bounds
    # NOTE(review): ``<= self.n_rows`` admits index == n_rows; if n_rows is
    # the number of rows of self.data this should be ``<`` -- confirm.
    assert (0 <= x_index <= self.n_rows)
    assert (0 <= y_index <= self.n_rows)
    row_x = np.asarray(self.data[x_index])
    row_y = np.asarray(self.data[y_index])
    # Explicit float: with an int n_cols, Python 2 integer division would
    # truncate every probability to 0.
    n = float(self.n_cols)
    summation = 0.0
    # Get unique values of random variables
    values_x = set(self.data[x_index])
    values_y = set(self.data[y_index])
    if debug:
        print('MI between')
        print(self.data[x_index])
        print(self.data[y_index])
    # Marginals of y do not depend on x; compute them once.
    stats_y = []
    for value_y in values_y:
        mask_y = row_y == value_y
        stats_y.append((value_y, mask_y, np.count_nonzero(mask_y) / n))
    for value_x in values_x:
        mask_x = row_x == value_x
        px = np.count_nonzero(mask_x) / n
        for value_y, mask_y, py in stats_y:
            # Joint frequency: observations where both values occur together.
            pxy = np.count_nonzero(mask_x & mask_y) / n
            if pxy > 0.0:
                summation += pxy * math.log((pxy / (px * py)), log_base)
            if debug:
                print('(%d,%d) px:%f py:%f pxy:%f' % (value_x, value_y, px, py, pxy))
    return summation
开发者ID:YukiShan,项目名称:amazon-review-spam,代码行数:29,代码来源:it_tools.py
示例14: entropy
def entropy(self, x_index, y_index, log_base, debug=False):
    """
    Calculate the joint entropy H(X, Y) of two random variables.

    The variables are the rows ``self.data[x_index]`` and
    ``self.data[y_index]``. Probabilities are empirical frequencies over
    ``self.n_cols`` observations. The result is ``-sum p(x,y) log p(x,y)``
    in base ``log_base``.
    """
    import numpy as np

    # NOTE(review): ``<= self.n_rows`` admits index == n_rows; if n_rows is
    # the number of rows of self.data this should be ``<`` -- confirm.
    assert (0 <= x_index <= self.n_rows)
    assert (0 <= y_index <= self.n_rows)
    row_x = np.asarray(self.data[x_index])
    row_y = np.asarray(self.data[y_index])
    # Explicit float avoids Python 2 integer division when n_cols is an int.
    n = float(self.n_cols)
    summation = 0.0
    # Get unique values of random variables
    values_x = set(self.data[x_index])
    values_y = set(self.data[y_index])
    if debug:
        print('Entropy between')
        print(self.data[x_index])
        print(self.data[y_index])
    masks_y = [(value_y, row_y == value_y) for value_y in values_y]
    for value_x in values_x:
        mask_x = row_x == value_x
        for value_y, mask_y in masks_y:
            pxy = np.count_nonzero(mask_x & mask_y) / n
            if pxy > 0.0:
                summation += pxy * math.log(pxy, log_base)
            if debug:
                print('(%d,%d) pxy:%f' % (value_x, value_y, pxy))
    # Return 0.0 itself rather than its negation, which would be -0.0.
    if summation == 0.0:
        return summation
    return - summation
开发者ID:YukiShan,项目名称:amazon-review-spam,代码行数:29,代码来源:it_tools.py
示例15: myTradingSystem
def myTradingSystem(DATE, CLOSE, settings):
    ''' Mean-reversion allocation across the markets in CLOSE.

    For each market, the ratio of its 40-row to its 200-row moving average
    of CLOSE is computed. The strategy goes long (+1) the market with the
    smallest ratio and short (-1) the one with the largest. The position row
    is normalised by its absolute sum. Returns ``(pos, settings)``, where
    pos has shape (1, nMarkets).
    '''
    n_markets = numpy.shape(CLOSE)[1]
    long_window = 200
    short_window = 40

    sma_long = numpy.sum(CLOSE[-long_window:, :], axis=0) / long_window
    sma_short = numpy.sum(CLOSE[-short_window:, :], axis=0) / short_window
    ratio = sma_short / sma_long

    # Contrarian view: buy the most depressed market, sell the most extended.
    long_hits = numpy.where(ratio == numpy.nanmin(ratio))
    short_hits = numpy.where(ratio == numpy.nanmax(ratio))

    pos = numpy.zeros((1, n_markets))
    pos[0, long_hits[0][0]] = 1
    pos[0, short_hits[0][0]] = -1

    # Weights over settings['markets'] must be normalised.
    pos = pos / numpy.nansum(abs(pos))
    return pos, settings
开发者ID:Quantiacs,项目名称:python-sample-strategies,代码行数:32,代码来源:meanReversion.py
示例16: implicit_black_box
def implicit_black_box(propensities, V, X, w, h, deter_vector, stoc_positions, positions, valid, deriv):
    """Integrate the deterministic rows of ``X`` over ``[0, h]`` with LSODA.

    Steps:
    1. Columns whose summed absolute derivative (from ``derivative_G``) is
       below 1e-10 are treated as having reached steady state. Their rows in
       a copy of ``valid`` are switched off before integrating.
    2. ``X[deter_vector, :]`` is integrated with scipy's ``ode`` (``lsoda``,
       Adams method, right-hand side ``f``).
    3. The integrated values are written back into ``X`` in place.

    Returns a new array equal to ``X`` with negative entries set to 0.0.
    ``deriv`` is accepted but unused.
    """
    from scipy.integrate import ode

    # Adjustment for systems reaching steady state
    temp = derivative_G(propensities, V, X, w, deter_vector, stoc_positions, positions, valid)
    valid_adjust_pos = np.sum(np.abs(temp), axis=0) < 1e-10
    # ``valid[:, :]`` on an ndarray is a view, not a copy: writing through it
    # used to switch these entries off in the caller's ``valid`` as well.
    valid_adjust = valid.copy()
    valid_adjust[valid_adjust_pos, :] = False

    deter_ode = ode(f).set_integrator("lsoda", method="adams", with_jacobian=False)
    deter_ode.set_initial_value(X[deter_vector, :].flatten(), 0).set_f_params(
        [propensities, V, X, deter_vector, stoc_positions, positions, valid_adjust, w]
    )
    while deter_ode.successful() and deter_ode.t < h:
        deter_ode.integrate(h)

    X[deter_vector, :] = deter_ode.y.reshape((np.sum(deter_vector), X.shape[1]))
    # Clip negatives introduced by the integrator.
    X = np.where(X < 0.0, 0.0, X)
    return X
开发者ID:vikramsunkara,项目名称:PyME,代码行数:34,代码来源:implicit_ODE.py
示例17: stats
def stats(t, snp=None):
    '''Return a record array with imputation statistics.

    There is one record per SNP in ``snp`` (default: all ``t.num_snps``
    SNPs). Each record holds:
    - dist_cm: the genetic distance taken from ``t.snp['dist_cm']``;
    - count / frequency: counts and frequencies of allele codes 1 and 2
      among the imputed samples ``t.sample_index_to_impute``;
    - call_rate: nonzero imputed calls divided by twice the number of
      imputed samples;
    - call_rate_training: see the NOTE below.
    '''
    T = t.sample_index_to_impute
    imputed = t.imputed_data[:, T, :]
    tot_to_impute = 2 * imputed.shape[1]
    snps = snp if snp is not None else np.arange(t.num_snps)
    record_dtype = [
        ('dist_cm', 'f4'),             # Genetic distance from beginning of chromosome
        ('count', '(2,)i4'),           # Allele count
        ('frequency', '(2,)f4'),       # Allele frequency
        ('call_rate', 'f4'),           # Imputation call rate
        ('call_rate_training', 'f4'),  # Training call rate (see NOTE below)
    ]
    result = np.zeros((len(snps),), dtype=record_dtype)
    # NOTE(review): this is a raw per-SNP count of nonzero training calls; it
    # is not divided by 2 * len(t.sample_index) despite the field name.
    training_calls = 1.0 * np.sum(np.sum(t.imputed_data[:, t.sample_index, :] != 0, axis=2), axis=1)
    for row, snp_index in enumerate(snps):
        genotypes = imputed[snp_index, :]
        c1 = np.count_nonzero(genotypes == 1)
        c2 = np.count_nonzero(genotypes == 2)
        # SMALL_FLOAT keeps the division defined when neither allele is called.
        total = c1 + c2 + SMALL_FLOAT
        f1, f2 = (1.0 * c1) / total, (1.0 * c2) / total
        call_rate = 1.0 * np.count_nonzero(genotypes) / tot_to_impute
        result[row] = (t.snp['dist_cm'][snp_index], [c1, c2], [f1, f2], call_rate, training_calls[snp_index])
    return result
开发者ID:orenlivne,项目名称:ober,代码行数:26,代码来源:impute_stats_rare.py
示例18: CalcErange
def CalcErange(inWS,ns,erange,binWidth):
    """Bin the X values of workspace ``inWS`` inside ``erange`` for the Fortran code.

    The X values of spectrum 0 that lie in ``[erange[0], erange[1]]`` are
    grouped into consecutive bins of ``binWidth`` points. Each group is
    reduced to its sum times ``1/binWidth``, i.e. its mean, and the result is
    padded to 4096 entries with ``PadArray``.

    Returns
    -------
    (nout, bnorm, Xout, X, Y, E)
        ``nout = [nbins, minIndex, maxIndex]``, ``bnorm = 1/binWidth``, the
        padded binned X values, and the ``X, Y, E`` arrays from ``GetXYE``.
        Raises ValueError if the number of in-range points is not a
        multiple of ``binWidth``.
    """
    # length of array in Fortran
    array_len = 4096
    binWidth = int(binWidth)
    bnorm = 1.0 / binWidth

    # get data from input workspace
    _, X, Y, E = GetXYE(inWS, ns, array_len)
    Xdata = mtd[inWS].readX(0)

    # get all x values within the energy range
    rangeMask = (Xdata >= erange[0]) & (Xdata <= erange[1])
    Xin = Xdata[rangeMask]

    # Indices of the range bounds in Xdata. minIndex is shifted by one
    # (presumably a 1-based index for Fortran -- TODO confirm); maxIndex is not.
    minIndex = np.where(Xdata == Xin[0])[0][0] + 1
    maxIndex = np.where(Xdata == Xin[-1])[0][0]

    # Reshape into rows of binWidth points. Floor division: true division
    # yields a float on Python 3, which reshape rejects.
    Xin = Xin.reshape(len(Xin) // binWidth, binWidth)

    # sum and normalise values in bins
    Xout = [sum(bin_val) * bnorm for bin_val in Xin]

    nbins = len(Xout)
    nout = [nbins, minIndex, maxIndex]

    # pad array for use in Fortran code
    Xout = PadArray(Xout, array_len)

    return nout, bnorm, Xout, X, Y, E
开发者ID:rosswhitfield,项目名称:mantid,代码行数:34,代码来源:IndirectBayes.py
示例19: prime_to_pixel
def prime_to_pixel(self, xprime, yprime, color=0):
    """Invert ``pixel_to_prime``: map ``(xprime, yprime)`` back to pixel ``(x, y)``.

    The colour term is removed first: ``p * color`` when ``color < riCut``,
    otherwise the constant ``q``. Then
    ``xprime = x + h0 + h1*x + h2*x**2 + h3*x**3`` is solved for ``x`` by
    Newton's method until ``|dx| <= 1e-10``, and ``y`` follows directly from
    the dRow polynomial in ``x``.

    The inputs are not modified. (The previous ``xprime -= ...`` rewrote the
    caller's arrays in place.)
    """
    color0 = self._get_ricut()
    g0, g1, g2, g3 = self._get_drow()
    h0, h1, h2, h3 = self._get_dcol()
    px, py, qx, qy = self._get_cscc()

    # #$(%*&^(%$%*& bad documentation.
    (px, py) = (py, px)
    (qx, qy) = (qy, qx)

    qx = qx * np.ones_like(xprime)
    qy = qy * np.ones_like(yprime)
    # Rebind instead of using -= so the caller's arrays stay untouched (and
    # integer inputs are promoted to float instead of failing to cast).
    xprime = xprime - np.where(color < color0, px * color, qx)
    yprime = yprime - np.where(color < color0, py * color, qy)

    # Now invert:
    # yprime = y + g0 + g1 * x + g2 * x**2 + g3 * x**3
    # xprime = x + h0 + h1 * x + h2 * x**2 + h3 * x**3
    x = xprime - h0
    # dumb-ass Newton's method
    dx = 1.
    # FIXME -- should just update the ones that aren't zero
    # FIXME -- should put in some failsafe...
    while np.max(np.abs(np.atleast_1d(dx))) > 1e-10:
        xp = x + h0 + h1 * x + h2 * x**2 + h3 * x**3
        dxpdx = 1 + h1 + h2 * 2*x + h3 * 3*x**2
        dx = (xprime - xp) / dxpdx
        x += dx
    y = yprime - (g0 + g1 * x + g2 * x**2 + g3 * x**3)
    return (x, y)
开发者ID:joshuawallace,项目名称:astrometry.net,代码行数:30,代码来源:common.py
示例20: multi_where
def multi_where(vec1, vec2):
    '''Given two vectors, multi_where returns a tuple of indices where those
    two vectors overlap.

    ****THIS FUNCTION HAS NOT BEEN TESTED ON N-DIMENSIONAL ARRAYS*******

    Inputs:
       2 numpy vectors

    Output:
       (xy, yx) where xy is a numpy vector containing the indices of the
       elements in vector 1 that are also in vector 2. yx is a vector
       containing, for each of those elements, the index of its first
       occurrence in vector 2. Both are np.intp arrays.

    Example:
    >> x = np.array([1,2,3,4,5])
    >> y = np.array([3,4,5,6,7])
    >> (xy,yx) = multi_where(x,y)
    >> xy
    array([2,3,4])
    >> yx
    array([0,1,2])
    '''
    one_in_two = []
    two_in_one = []
    for i in range(vec1.shape[0]):
        matches = np.where(vec2 == vec1[i])[0]
        if matches.shape[0]:
            one_in_two.append(i)
            two_in_one.append(matches[0])
    # np.intp rather than np.int8: int8 silently wrapped indices above 127.
    return (np.array(one_in_two, dtype=np.intp), np.array(two_in_one, dtype=np.intp))
开发者ID:eigenbrot,项目名称:snakes,代码行数:29,代码来源:ADEUtils.py
注：本文中的numpy.where函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论