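本文整理汇总了 Python 中 sklearn.neighbors.ball_tree.BallTree 类的典型用法代码示例（注：较新版本的 scikit-learn 已移除 sklearn.neighbors.ball_tree 模块，应改用 from sklearn.neighbors import BallTree）。如果您正苦于以下问题：Python BallTree 类的具体用法？Python BallTree 怎么用？Python BallTree 使用的例子？那么恭喜您，这里精选的类代码示例或许可以为您提供帮助。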
在下文中一共展示了BallTree类的18个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test_ball_tree_pickle
def test_ball_tree_pickle():
    """Check that a pickled BallTree answers ``query`` exactly like the original.

    Both a default-metric tree and a tree built with a Python callable metric
    (``dist_func``) are round-tripped through every pickle protocol supported
    by the running interpreter.
    """
    rng = check_random_state(0)
    X = rng.random_sample((10, 3))
    bt1 = BallTree(X, leaf_size=1)
    # Test if BallTree with callable metric is picklable
    bt1_pyfunc = BallTree(X, metric=dist_func, leaf_size=1, p=2)
    # BallTree.query returns (distances, indices), in that order.
    dist1, ind1 = bt1.query(X)
    dist1_pyfunc, ind1_pyfunc = bt1_pyfunc.query(X)

    def check_pickle_protocol(protocol):
        """Round-trip both trees with ``protocol`` and compare query results."""
        bt2 = pickle.loads(pickle.dumps(bt1, protocol=protocol))
        bt2_pyfunc = pickle.loads(pickle.dumps(bt1_pyfunc, protocol=protocol))
        dist2, ind2 = bt2.query(X)
        dist2_pyfunc, ind2_pyfunc = bt2_pyfunc.query(X)
        assert_array_almost_equal(ind1, ind2)
        assert_array_almost_equal(dist1, dist2)
        assert_array_almost_equal(ind1_pyfunc, ind2_pyfunc)
        assert_array_almost_equal(dist1_pyfunc, dist2_pyfunc)
        assert isinstance(bt2, BallTree)

    # Exercise every protocol, not just 0-2, so newer formats are covered too.
    for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
        check_pickle_protocol(protocol)
开发者ID:allefpablo,项目名称:scikit-learn,代码行数:31,代码来源:test_ball_tree.py
示例2: similar_products2
def similar_products2(deep_f, k=5):
    """Return the products whose feature vectors are nearest to ``deep_f``.

    Builds a BallTree over ``get_features()`` of every ``Product`` and looks
    up the ``k`` nearest neighbours of the query vector.

    Parameters
    ----------
    deep_f : array-like
        Query feature vector (1-D), or a 2-D array whose first row is used.
        Presumably produced by the same extractor as
        ``Product.get_features`` -- TODO confirm.
    k : int, optional
        Number of neighbours to look up (default 5, as before). It is capped
        at the number of products, because ``BallTree.query`` cannot return
        more neighbours than there are indexed points.

    Returns
    -------
    QuerySet
        The matching ``Product`` rows, or an empty queryset if there are no
        products. NOTE(review): ``asin__in`` filtering does not preserve
        nearest-first order.
    """
    qs = Product.objects.all()
    feature_list = []
    asin_list = []
    for prod in qs:
        feature_list.append(prod.get_features())
        asin_list.append(prod.asin)
    if not feature_list:
        # A BallTree cannot be built from an empty catalogue.
        return Product.objects.none()
    tree = BallTree(np.asarray(feature_list))
    # query() needs a 2-D array; promote a single 1-D vector to one row.
    query = np.atleast_2d(deep_f)
    dist, ind = tree.query(query, k=min(k, len(asin_list)))
    print(ind)
    # Neighbours of the first (usually only) query row.
    recommended_asins = [asin_list[i] for i in ind[0]]
    return Product.objects.filter(asin__in=recommended_asins)
开发者ID:vatsalchanana,项目名称:image-search,代码行数:35,代码来源:views.py
示例3: test_ball_tree_pickle
def test_ball_tree_pickle():
    """Check that pickled BallTrees (plain and callable-metric) query like the originals.

    Uses a local RandomState (same values as seeding the global RNG with 0)
    so the test does not mutate global NumPy random state.
    """
    X = np.random.RandomState(0).random_sample((10, 3))
    bt1 = BallTree(X, leaf_size=1)
    # Test if BallTree with callable metric is picklable
    bt1_pyfunc = BallTree(X, metric=dist_func, leaf_size=1, p=2)
    # BallTree.query returns (distances, indices), in that order.
    dist1, ind1 = bt1.query(X)
    dist1_pyfunc, ind1_pyfunc = bt1_pyfunc.query(X)

    def check_pickle_protocol(protocol):
        """Round-trip both trees with ``protocol`` and compare query results."""
        bt2 = pickle.loads(pickle.dumps(bt1, protocol=protocol))
        bt2_pyfunc = pickle.loads(pickle.dumps(bt1_pyfunc, protocol=protocol))
        dist2, ind2 = bt2.query(X)
        dist2_pyfunc, ind2_pyfunc = bt2_pyfunc.query(X)
        assert_array_almost_equal(ind1, ind2)
        assert_array_almost_equal(dist1, dist2)
        assert_array_almost_equal(ind1_pyfunc, ind2_pyfunc)
        assert_array_almost_equal(dist1_pyfunc, dist2_pyfunc)

    # Call the checks directly instead of yielding them: yield-based test
    # generators are not supported by pytest >= 4.
    for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
        check_pickle_protocol(protocol)
开发者ID:Afey,项目名称:scikit-learn,代码行数:29,代码来源:test_ball_tree.py
示例4: similar_products
def similar_products(product, n_recommendations=4):
    """Return products whose feature vectors are nearest to ``product``'s.

    Builds a BallTree over ``get_features()`` of every ``Product`` and returns
    up to ``n_recommendations`` nearest neighbours, excluding ``product``
    itself.

    Parameters
    ----------
    product : Product
        The product to find look-alikes for. It is matched against the
        catalogue by ``asin``. If it is not in the catalogue, its own
        ``get_features()`` vector is used as the query instead of silently
        falling back to the first product.
    n_recommendations : int, optional
        Maximum number of products to recommend (default 4, as before).

    Returns
    -------
    QuerySet
        The recommended ``Product`` rows, or an empty queryset if there are
        no products. NOTE(review): ``asin__in`` filtering does not preserve
        nearest-first order.
    """
    qs = Product.objects.all()
    feature_list = []
    asin_list = []
    product_index = None
    for i, prod in enumerate(qs):
        feature_list.append(prod.get_features())
        asin_list.append(prod.asin)
        if prod.asin == product.asin:
            product_index = i
    if not feature_list:
        # A BallTree cannot be built from an empty catalogue.
        return Product.objects.none()
    nparray = np.asarray(feature_list)
    if product_index is None:
        query = np.atleast_2d(product.get_features())
    else:
        # Slice (not index) so the query stays 2-D, as BallTree.query requires.
        query = nparray[product_index:product_index + 1]
    tree = BallTree(nparray)
    # Ask for one extra neighbour because the product normally finds itself.
    # k may not exceed the number of indexed products.
    k = min(n_recommendations + 1, len(asin_list))
    dist, ind = tree.query(query, k=k)
    print(ind)
    # Drop the product itself explicitly: with duplicate feature vectors it is
    # not guaranteed to be the first hit.
    recommended_asins = [
        asin_list[i] for i in ind[0] if i != product_index
    ][:n_recommendations]
    return Product.objects.filter(asin__in=recommended_asins)
开发者ID:vatsalchanana,项目名称:image-search,代码行数:28,代码来源:views.py
示例5: check_neighbors
def check_neighbors(dualtree, breadth_first, k, metric, kwargs):
    """Assert that BallTree k-NN distances agree with a brute-force search.

    A tree over the module-level ``X`` is queried with the module-level ``Y``
    using the given traversal options (``dualtree``, ``breadth_first``).
    ``kwargs`` holds extra metric parameters, forwarded to both the tree and
    the brute-force reference.
    """
    tree = BallTree(X, leaf_size=1, metric=metric, **kwargs)
    tree_dist, _ = tree.query(Y, k, dualtree=dualtree,
                              breadth_first=breadth_first)
    brute_dist, _ = brute_force_neighbors(X, Y, k, metric, **kwargs)
    # Indices are deliberately not compared: ties in distance may legitimately
    # be returned in a different order, while the distances themselves cannot
    # differ.
    assert_array_almost_equal(tree_dist, brute_dist)
开发者ID:albertotb,项目名称:scikit-learn,代码行数:8,代码来源:test_ball_tree.py
示例6: test_query_haversine
def test_query_haversine():
    """Haversine-metric 5-NN results must equal brute force (distances and indices)."""
    n_neighbors = 5
    angles = 2 * np.pi * check_random_state(0).random_sample((40, 2))
    tree = BallTree(angles, leaf_size=1, metric='haversine')
    tree_result = tree.query(angles, k=n_neighbors)
    brute_result = brute_force_neighbors(angles, angles, k=n_neighbors,
                                         metric='haversine')
    tree_dist, tree_ind = tree_result
    brute_dist, brute_ind = brute_result
    # Compare distances first, then indices.
    for got, expected in ((tree_dist, brute_dist), (tree_ind, brute_ind)):
        assert_array_almost_equal(got, expected)
开发者ID:BranYang,项目名称:scikit-learn,代码行数:9,代码来源:test_ball_tree.py
示例7: test_query_haversine
def test_query_haversine():
    """Haversine-metric 5-NN results must equal brute force (distances and indices).

    A local RandomState(0) yields the same samples as seeding the global RNG
    with 0, without leaking random state into other tests.
    """
    rng = np.random.RandomState(0)
    X = 2 * np.pi * rng.random_sample((40, 2))
    bt = BallTree(X, leaf_size=1, metric='haversine')
    dist1, ind1 = bt.query(X, k=5)
    dist2, ind2 = brute_force_neighbors(X, X, k=5, metric='haversine')
    assert_array_almost_equal(dist1, dist2)
    assert_array_almost_equal(ind1, ind2)
开发者ID:Afey,项目名称:scikit-learn,代码行数:9,代码来源:test_ball_tree.py
示例8: test_ball_tree_kde
def test_ball_tree_kde(kernel, h, rtol, atol, breadth_first, n_samples=100,
                       n_features=3):
    """BallTree.kernel_density must match a slow direct kernel sum.

    The tolerance passed to the tree is (atol, rtol). The comparison uses
    rtol floored at 1e-7, so a zero requested rtol still allows
    floating-point noise.
    """
    shape = (n_samples, n_features)
    rng = np.random.RandomState(0)
    X = rng.random_sample(shape)
    Y = rng.random_sample(shape)
    tree = BallTree(X, leaf_size=10)
    expected = compute_kernel_slow(Y, X, kernel, h)
    actual = tree.kernel_density(Y, h, atol=atol, rtol=rtol, kernel=kernel,
                                 breadth_first=breadth_first)
    assert_allclose(actual, expected, atol=atol, rtol=max(rtol, 1e-7))
开发者ID:allefpablo,项目名称:scikit-learn,代码行数:14,代码来源:test_ball_tree.py
示例9: test_gaussian_kde
def test_gaussian_kde(n_samples=1000):
    # Compare gaussian KDE results to scipy.stats.gaussian_kde
    from scipy.stats import gaussian_kde
    rng = check_random_state(0)
    x_in = rng.normal(0, 1, n_samples)
    x_out = np.linspace(-5, 5, 30)
    # The tree and the sample std depend only on the data, not on the
    # bandwidth, so build/compute them once instead of once per h.
    bt = BallTree(x_in[:, None])
    x_std = np.std(x_in)
    for h in [0.01, 0.1, 1]:
        # scipy's scalar bw_method is a factor applied to the data's spread,
        # so divide the std out to get an absolute bandwidth of h.
        gkde = gaussian_kde(x_in, bw_method=h / x_std)
        # kernel_density returns an unnormalized sum; divide by n to get a density.
        dens_bt = bt.kernel_density(x_out[:, None], h) / n_samples
        dens_gkde = gkde.evaluate(x_out)
        assert_array_almost_equal(dens_bt, dens_gkde, decimal=3)
开发者ID:BranYang,项目名称:scikit-learn,代码行数:15,代码来源:test_ball_tree.py
示例10: test_ball_tree_query
def test_ball_tree_query(metric, k, dualtree, breadth_first):
    """BallTree k-NN distances must match brute force for each metric/traversal combo.

    ``METRICS[metric]`` supplies the extra keyword arguments for the metric.
    """
    rng = check_random_state(0)
    X = rng.random_sample((40, DIMENSION))
    Y = rng.random_sample((10, DIMENSION))
    metric_kwargs = METRICS[metric]
    tree = BallTree(X, leaf_size=1, metric=metric, **metric_kwargs)
    tree_dist, _ = tree.query(Y, k, dualtree=dualtree,
                              breadth_first=breadth_first)
    brute_dist, _ = brute_force_neighbors(X, Y, k, metric, **metric_kwargs)
    # Indices are not compared: equal distances may be returned in a
    # different order, but the distances themselves must agree.
    assert_array_almost_equal(tree_dist, brute_dist)
开发者ID:allefpablo,项目名称:scikit-learn,代码行数:15,代码来源:test_ball_tree.py
示例11: test_ball_tree_query_metrics
def test_ball_tree_query_metrics(metric):
    """BallTree 5-NN distances must match brute force for boolean/discrete metrics.

    Boolean metrics get 0/1-valued data; discrete metrics get integers 0..4.

    Raises
    ------
    ValueError
        If ``metric`` is in neither BOOLEAN_METRICS nor DISCRETE_METRICS,
        since no suitable test data is defined for it.
    """
    rng = check_random_state(0)
    if metric in BOOLEAN_METRICS:
        X = rng.random_sample((40, 10)).round(0)
        Y = rng.random_sample((10, 10)).round(0)
    elif metric in DISCRETE_METRICS:
        X = (4 * rng.random_sample((40, 10))).round(0)
        Y = (4 * rng.random_sample((10, 10))).round(0)
    else:
        # Without this branch X/Y would be unbound and the test would die
        # with a confusing UnboundLocalError.
        raise ValueError("no test data defined for metric %r" % (metric,))
    k = 5
    bt = BallTree(X, leaf_size=1, metric=metric)
    dist1, ind1 = bt.query(Y, k)
    dist2, ind2 = brute_force_neighbors(X, Y, k, metric)
    assert_array_almost_equal(dist1, dist2)
开发者ID:allefpablo,项目名称:scikit-learn,代码行数:15,代码来源:test_ball_tree.py
示例12: test_ball_tree_pickle
def test_ball_tree_pickle():
    """Check that a pickled BallTree answers ``query`` exactly like the original.

    Uses a local RandomState (same values as seeding the global RNG with 0)
    so the test does not mutate global NumPy random state.
    """
    import pickle
    X = np.random.RandomState(0).random_sample((10, 3))
    bt1 = BallTree(X, leaf_size=1)
    # BallTree.query returns (distances, indices), in that order.
    dist1, ind1 = bt1.query(X)

    def check_pickle_protocol(protocol):
        """Round-trip the tree with ``protocol`` and compare query results."""
        bt2 = pickle.loads(pickle.dumps(bt1, protocol=protocol))
        dist2, ind2 = bt2.query(X)
        assert_allclose(ind1, ind2)
        assert_allclose(dist1, dist2)

    # Call the checks directly instead of yielding them: yield-based test
    # generators are not supported by pytest >= 4.
    for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
        check_pickle_protocol(protocol)
开发者ID:kinnskogr,项目名称:scikit-learn,代码行数:16,代码来源:test_ball_tree.py
示例13: test_ball_tree_query_radius
def test_ball_tree_query_radius(n_samples=100, n_features=10):
    """query_radius must return exactly the samples within radius r of the origin.

    Radii are swept between the distances of the first and last sample.
    """
    # Local RandomState: same samples as np.random.seed(0), no global side effect.
    rng = np.random.RandomState(0)
    X = 2 * rng.random_sample(size=(n_samples, n_features)) - 1
    query_pt = np.zeros(n_features, dtype=float)
    eps = 1E-15  # roundoff error can cause test to fail
    bt = BallTree(X, leaf_size=5)
    # Euclidean distance of every sample from the query point.
    rad = np.sqrt(((X - query_pt) ** 2).sum(1))
    for r in np.linspace(rad[0], rad[-1], 100):
        # query_radius expects a 2-D batch of query points; wrap the single
        # point and take its (only) result.
        ind = bt.query_radius([query_pt], r + eps)[0]
        i = np.where(rad <= r + eps)[0]
        ind.sort()
        i.sort()
        assert_array_almost_equal(i, ind)
开发者ID:Afey,项目名称:scikit-learn,代码行数:17,代码来源:test_ball_tree.py
示例14: test_gaussian_kde
def test_gaussian_kde(n_samples=1000):
    """Compare gaussian KDE results to scipy.stats.gaussian_kde"""
    from scipy.stats import gaussian_kde
    # Local RandomState: same samples as np.random.seed(0), no global side effect.
    rng = np.random.RandomState(0)
    x_in = rng.normal(0, 1, n_samples)
    x_out = np.linspace(-5, 5, 30)
    # The tree and the sample std depend only on the data, not on h, so
    # build/compute them once instead of once per bandwidth.
    bt = BallTree(x_in[:, None])
    x_std = np.std(x_in)
    for h in [0.01, 0.1, 1]:
        try:
            # scalar bw_method is a factor on the data's spread; divide the
            # std out so the absolute bandwidth is h.
            gkde = gaussian_kde(x_in, bw_method=h / x_std)
        except TypeError:
            raise SkipTest("Old version of scipy, doesn't accept explicit bandwidth.")
        # kernel_density returns an unnormalized sum; divide by n to get a density.
        dens_bt = bt.kernel_density(x_out[:, None], h) / n_samples
        dens_gkde = gkde.evaluate(x_out)
        assert_array_almost_equal(dens_bt, dens_gkde, decimal=3)
开发者ID:99plus2,项目名称:scikit-learn,代码行数:18,代码来源:test_ball_tree.py
示例15: test_ball_tree_query_radius_distance
def test_ball_tree_query_radius_distance(n_samples=100, n_features=10):
    """Distances returned by query_radius must equal the directly computed ones."""
    # Local RandomState: same samples as np.random.seed(0), no global side effect.
    rng = np.random.RandomState(0)
    X = 2 * rng.random_sample(size=(n_samples, n_features)) - 1
    query_pt = np.zeros(n_features, dtype=float)
    eps = 1E-15  # roundoff error can cause test to fail
    bt = BallTree(X, leaf_size=5)
    rad = np.sqrt(((X - query_pt) ** 2).sum(1))
    for r in np.linspace(rad[0], rad[-1], 100):
        # query_radius expects a 2-D batch of query points; wrap the single
        # point and take its (only) result.
        ind, dist = bt.query_radius([query_pt], r + eps, return_distance=True)
        ind = ind[0]
        dist = dist[0]
        d = np.sqrt(((query_pt - X[ind]) ** 2).sum(1))
        assert_array_almost_equal(d, dist)
开发者ID:Afey,项目名称:scikit-learn,代码行数:18,代码来源:test_ball_tree.py
示例16: test_gaussian_kde
def test_gaussian_kde(n_samples=1000):
    """Compare gaussian KDE results to scipy.stats.gaussian_kde"""
    from scipy.stats import gaussian_kde
    # Local RandomState: same samples as np.random.seed(0), no global side effect.
    rng = np.random.RandomState(0)
    x_in = rng.normal(0, 1, n_samples)
    x_out = np.linspace(-5, 5, 30)
    # The tree and the sample std depend only on the data, not on h, so
    # build/compute them once instead of once per bandwidth.
    bt = BallTree(x_in[:, None])
    x_std = np.std(x_in)
    for h in [0.01, 0.1, 1]:
        try:
            gkde = gaussian_kde(x_in, bw_method=h / x_std)
        except TypeError:
            # older versions of scipy don't accept explicit bandwidth
            raise SkipTest
        # kernel_density returns an unnormalized sum; divide by n to get a density.
        dens_bt = bt.kernel_density(x_out[:, None], h) / n_samples
        dens_gkde = gkde.evaluate(x_out)
        assert_allclose(dens_bt, dens_gkde, rtol=1E-3, atol=1E-3)
开发者ID:kinnskogr,项目名称:scikit-learn,代码行数:19,代码来源:test_ball_tree.py
示例17: _nonlocalmeans_clustered
def _nonlocalmeans_clustered(img, n_small=5, n_components=9, n_neighbors=30, h=10):
    """Denoise a 2-D image with a PCA-reduced non-local-means filter.

    Every pixel's (2*n_small+1)-square neighbourhood (reflect-padded at the
    borders) is flattened into a patch. The patch is projected onto
    ``n_components`` principal components and indexed in a BallTree. Each
    output pixel is the weighted mean of the pixels whose patches are its
    nearest neighbours, with weights ``exp(-dist / h**2)``.

    Parameters
    ----------
    img : 2-D array
        Input image; ``img.shape`` is unpacked as ``(n_rows, n_cols)``.
    n_small : int
        Patch half-width.
    n_components : int
        Number of PCA components kept per patch.
    n_neighbors : int
        Neighbours fetched per patch. The first (nearest) hit is dropped,
        presumably the patch itself.
    h : float
        Filtering strength; larger values average more aggressively.

    Returns
    -------
    new_img : array with the same shape and dtype as ``img``
        NOTE(review): for an integer ``img`` the weighted means are
        truncated, not rounded, when stored.
    """
    side = 2 * n_small + 1
    Nw = side ** 2
    h2 = h * h
    n_rows, n_cols = img.shape
    # Row/column offsets of every patch element relative to the patch centre.
    small_rows, small_cols = np.indices((side, side)) - n_small
    # One flattened patch per pixel, stored in row-major pixel order
    # (patch index == r * n_cols + c).
    n_padded = np.pad(img, n_small, mode='reflect')
    patches = np.zeros((n_rows * n_cols, Nw))
    n = 0
    for r in range(n_small, n_small + n_rows):
        for c in range(n_small, n_small + n_cols):
            patches[n, :] = n_padded[r + small_rows, c + small_cols].ravel()
            n += 1
    transformed = PCA(n_components=n_components).fit_transform(patches)
    # index the patches into a tree
    tree = BallTree(transformed, leaf_size=2)
    print("Denoising")
    # Patch index i corresponds to flat pixel index i in row-major order.
    flat_img = img.ravel()
    new_img = np.zeros_like(img)
    for r in range(n_rows):
        row = slice(r * n_cols, (r + 1) * n_cols)
        # One batched 2-D query per image row: BallTree.query requires 2-D
        # input, and this is far faster than one query per pixel.
        dist, ind = tree.query(transformed[row], k=n_neighbors)
        w = np.exp(-dist[:, 1:] / h2)
        colors = flat_img[ind[:, 1:]]
        new_img[r, :] = np.sum(w * colors, axis=1) / np.sum(w, axis=1)
    return new_img
开发者ID:dsvision,项目名称:nlm,代码行数:36,代码来源:nlm.py
示例18: check_neighbors
def check_neighbors(metric):
    """Assert BallTree k-NN distances match brute force for ``metric``.

    Relies on the module-level ``X`` (indexed points), ``Y`` (queries) and
    ``k`` (neighbour count).
    """
    tree = BallTree(X, leaf_size=1, metric=metric)
    tree_dist, _ = tree.query(Y, k)
    brute_dist, _ = brute_force_neighbors(X, Y, k, metric)
    assert_array_almost_equal(tree_dist, brute_dist)
开发者ID:Afey,项目名称:scikit-learn,代码行数:5,代码来源:test_ball_tree.py
注：本文中的 sklearn.neighbors.ball_tree.BallTree 类示例由纯净天空整理自 Github/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。
请发表评论