本文整理汇总了Python中sklearn.utils.testing.assert_in函数的典型用法代码示例。如果您正苦于以下问题：Python assert_in函数的具体用法？Python assert_in怎么用？Python assert_in使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了assert_in函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test_root_import_all_completeness
def test_root_import_all_completeness():
    """Every public top-level sklearn module must be listed in ``sklearn.__all__``.

    Nested modules (dotted names), private modules (leading underscore) and a
    fixed set of exempted names are not checked.
    """
    EXCEPTIONS = ('utils', 'tests', 'base', 'setup')
    # onerror swallows import errors raised while walking the package tree.
    package_infos = pkgutil.walk_packages(path=sklearn.__path__,
                                          onerror=lambda _: None)
    for _finder, name, _is_pkg in package_infos:
        is_nested = '.' in name
        is_private = name.startswith('_')
        if not (is_nested or is_private or name in EXCEPTIONS):
            assert_in(name, sklearn.__all__)
开发者ID:MartinThoma,项目名称:scikit-learn,代码行数:7,代码来源:test_common.py
示例2: test_dump
def test_dump():
    """Round-trip the svmlight test file through ``dump_svmlight_file`` and
    ``load_svmlight_file``.

    For dense and sparse ``X``, both index bases and float32/float64, checks
    the two header comment lines (library version, index base) and that the
    reloaded dtype, data and labels match the input.
    """
    Xs, y = load_svmlight_file(datafile)
    Xd = Xs.toarray()
    for X in (Xs, Xd):
        for zero_based in (True, False):
            for dtype in [np.float32, np.float64]:
                f = BytesIO()
                dump_svmlight_file(X.astype(dtype), y, f, zero_based=zero_based)
                f.seek(0)
                # BytesIO yields bytes; decode so the substring checks compare
                # text with text (``str in bytes`` raises TypeError on Python 3).
                comment = f.readline().decode("utf-8")
                assert_in("scikit-learn %s" % sklearn.__version__, comment)
                comment = f.readline().decode("utf-8")
                assert_in(["one", "zero"][zero_based] + "-based", comment)
                X2, y2 = load_svmlight_file(f, dtype=dtype,
                                            zero_based=zero_based)
                assert_equal(X2.dtype, dtype)
                # allow a rounding error at the last decimal place
                decimal = 4 if dtype == np.float32 else 15
                assert_array_almost_equal(Xd.astype(dtype), X2.toarray(),
                                          decimal)
                assert_array_equal(y, y2)
开发者ID:kkuunnddaann,项目名称:scikit-learn,代码行数:33,代码来源:test_svmlight_format.py
示例3: test_dump
def test_dump():
    """Round-trip the svmlight test file through ``dump_svmlight_file`` and
    ``load_svmlight_file`` with a header comment.

    For dense and sparse ``X``, both index bases and float32/float64, checks
    the two header comment lines (library version, index base) and that the
    reloaded dtype, data and labels match the input.
    """
    Xs, y = load_svmlight_file(datafile)
    Xd = Xs.toarray()
    for X in (Xs, Xd):
        for zero_based in (True, False):
            for dtype in [np.float32, np.float64]:
                f = BytesIO()
                # we need to pass a comment to get the version info in;
                # LibSVM doesn't grok comments so they're not put in by
                # default anymore.
                dump_svmlight_file(X.astype(dtype), y, f, comment="test",
                                   zero_based=zero_based)
                f.seek(0)
                # BytesIO yields bytes; decode so the substring checks compare
                # text with text (``str in bytes`` raises TypeError on Python 3).
                comment = f.readline().decode("utf-8")
                assert_in("scikit-learn %s" % sklearn.__version__, comment)
                comment = f.readline().decode("utf-8")
                assert_in(["one", "zero"][zero_based] + "-based", comment)
                X2, y2 = load_svmlight_file(f, dtype=dtype,
                                            zero_based=zero_based)
                assert_equal(X2.dtype, dtype)
                # allow a rounding error at the last decimal place
                decimal = 4 if dtype == np.float32 else 15
                assert_array_almost_equal(Xd.astype(dtype), X2.toarray(),
                                          decimal)
                assert_array_equal(y, y2)
开发者ID:yzhy,项目名称:scikit-learn,代码行数:32,代码来源:test_svmlight_format.py
示例4: test_big_input
def test_big_input():
    """Test if the warning for too large inputs is appropriate.

    Values of 1e40 exceed the float32 range, so ``fit`` is expected to raise a
    ValueError whose message mentions float32.
    """
    X = np.repeat(10 ** 40., 4).astype(np.float64).reshape(-1, 1)
    clf = DecisionTreeClassifier()
    try:
        clf.fit(X, [0, 1, 0, 1])
    except ValueError as e:
        assert_in("float32", str(e))
    else:
        # Without this branch a fit that silently accepts the input would
        # let the test pass without checking anything.
        raise AssertionError("DecisionTreeClassifier.fit did not raise "
                             "ValueError for inputs too large for float32")
开发者ID:Carol-Hu,项目名称:scikit-learn,代码行数:8,代码来源:test_tree.py
示例5: test_sparse_precomputed
def test_sparse_precomputed():
    """SVC with ``kernel='precomputed'`` must reject a sparse Gram matrix.

    The rejection is expected as a TypeError mentioning "Sparse precomputed".
    """
    clf = svm.SVC(kernel='precomputed')
    sparse_gram = sparse.csr_matrix([[1, 0], [0, 1]])
    try:
        clf.fit(sparse_gram, [0, 1])
    except TypeError as e:
        assert_in("Sparse precomputed", str(e))
    else:
        # An explicit raise (unlike a bare ``assert``) is not stripped by -O.
        raise AssertionError("SVC.fit accepted a sparse precomputed kernel")
开发者ID:abhisg,项目名称:scikit-learn,代码行数:8,代码来源:test_svm.py
示例6: test_dump
def test_dump():
    """Round-trip every X/y layout through the svmlight dumper and loader.

    X and y are each tried sparse, dense and as a row-sliced sparse matrix
    (row fancy-indexing can leave CSR ``.indices`` unsorted, so the reloaded
    matrix is also checked for sorted indices). Both index bases and three
    dtypes are covered, together with the two header comment lines.
    """
    X_sparse, y_dense = load_svmlight_file(datafile)
    X_dense = X_sparse.toarray()
    y_sparse = sp.csr_matrix(y_dense)
    X_sliced = X_sparse[np.arange(X_sparse.shape[0])]
    y_sliced = y_sparse[np.arange(y_sparse.shape[0])]

    for X in (X_sparse, X_dense, X_sliced):
        for y in (y_sparse, y_dense, y_sliced):
            if sp.issparse(y) and y.shape[0] == 1:
                # A sparse y must be passed as (n_samples, n_labels).
                y = y.T
            for zero_based in (True, False):
                for dtype in [np.float32, np.float64, np.int32]:
                    buf = BytesIO()
                    # A comment has to be supplied for the version header to
                    # be written: LibSVM cannot parse comments, so the dumper
                    # no longer emits them by default.
                    dump_svmlight_file(X.astype(dtype), y, buf,
                                       comment="test", zero_based=zero_based)
                    buf.seek(0)
                    header = str(buf.readline(), "utf-8")
                    assert_in("scikit-learn %s" % sklearn.__version__, header)
                    header = str(buf.readline(), "utf-8")
                    assert_in(["one", "zero"][zero_based] + "-based", header)

                    X2, y2 = load_svmlight_file(buf, dtype=dtype,
                                                zero_based=zero_based)
                    assert_equal(X2.dtype, dtype)
                    assert_array_equal(X2.sorted_indices().indices,
                                       X2.indices)
                    X2_dense = X2.toarray()
                    # allow a rounding error at the last decimal place
                    decimal = 4 if dtype == np.float32 else 15
                    assert_array_almost_equal(X_dense.astype(dtype),
                                              X2_dense, decimal)
                    assert_array_almost_equal(y_dense.astype(dtype),
                                              y2, decimal)
开发者ID:mikebotazzo,项目名称:scikit-learn,代码行数:57,代码来源:test_svmlight_format.py
示例7: test_boundaries
def test_boundaries():
    """DBSCAN's ``min_samples`` and ``eps`` thresholds are both inclusive."""
    # min_samples counts the point itself, so two points within eps make
    # point 0 a core sample.
    core_samples, _ = dbscan([[0], [1]], eps=2, min_samples=2)
    assert_in(0, core_samples)

    # A neighbour lying exactly at distance eps is counted...
    core_samples, _ = dbscan([[0], [1], [1]], eps=1, min_samples=2)
    assert_in(0, core_samples)
    # ...but not once eps drops just below that distance.
    core_samples, _ = dbscan([[0], [1], [1]], eps=.99, min_samples=2)
    assert_not_in(0, core_samples)
开发者ID:jorgedavid22,项目名称:scikit-learn,代码行数:9,代码来源:test_dbscan.py
示例8: check_parameters_default_constructible
def check_parameters_default_constructible(name, Estimator):
    """Check that ``Estimator`` builds with defaults and that its ``__init__``
    only stores its parameters.

    Parameters
    ----------
    name : str
        Estimator name. Members of ``META_ESTIMATORS`` are constructed with an
        ``LDA()`` instance as their first argument.
    Estimator : type
        The estimator class under test.
    """
    classifier = LDA()
    # test default-constructibility
    # get rid of deprecation warnings
    with warnings.catch_warnings(record=True):
        if name in META_ESTIMATORS:
            estimator = Estimator(classifier)
        else:
            estimator = Estimator()
        # test cloning
        clone(estimator)
        # test __repr__
        repr(estimator)
        # test that set_params returns self
        assert_true(estimator.set_params() is estimator)

        # test if init does nothing but set parameters
        # this is important for grid_search etc.
        # We get the default parameters from init and then
        # compare these against the actual values of the attributes.

        # this comes from getattr. Gets rid of deprecation decorator.
        init = getattr(estimator.__init__, 'deprecated_original',
                       estimator.__init__)
        # inspect.getargspec was removed in Python 3.11; prefer
        # getfullargspec when it exists. Both results expose .args/.defaults.
        get_argspec = (getattr(inspect, 'getfullargspec', None)
                       or inspect.getargspec)
        try:
            spec = get_argspec(init)
        except TypeError:
            # init is not a python function.
            # true for mixins
            return
        args = spec.args
        # defaults is None when no parameter has a default; use an empty
        # tuple so the length check below fails cleanly instead of crashing.
        defaults = spec.defaults or ()
        params = estimator.get_params()
        if name in META_ESTIMATORS:
            # they need a non-default argument (skip self and the estimator)
            args = args[2:]
        else:
            # skip self
            args = args[1:]
        if args:
            # non-empty list: every remaining parameter needs a default
            assert_equal(len(args), len(defaults))
        else:
            return
        for arg, default in zip(args, defaults):
            assert_in(type(default), [str, int, float, bool, tuple, type(None),
                                      np.float64, types.FunctionType, Memory])
            if arg not in params:
                # deprecated parameter, not in get_params
                assert_true(default is None)
                continue

            if isinstance(params[arg], np.ndarray):
                assert_array_equal(params[arg], default)
            else:
                assert_equal(params[arg], default)
开发者ID:Afey,项目名称:scikit-learn,代码行数:53,代码来源:estimator_checks.py
示例9: test_dump
def test_dump():
    """Round-trip the svmlight test file through dump/load for several layouts.

    ``X`` is tried sparse, dense and as a row-sliced sparse matrix (row
    indexing can leave CSR ``.indices`` unsorted, so the reloaded matrix is
    checked for sorted indices). Both index bases and three dtypes are
    covered, together with the two header comment lines the dumper writes.
    """
    Xs, y = load_svmlight_file(datafile)
    Xd = Xs.toarray()
    Xsliced = Xs[np.arange(Xs.shape[0])]

    def to_text(line):
        # Python 3: decode the bytes read back from BytesIO.
        # Python 2: str() accepts no encoding argument (TypeError) and the
        # line is already a native str, so it is returned unchanged.
        try:
            return str(line, "utf-8")
        except TypeError:  # fails in Python 2.x
            return line

    for X in (Xs, Xd, Xsliced):
        for zero_based in (True, False):
            for dtype in [np.float32, np.float64, np.int32]:
                buf = BytesIO()
                # A comment has to be passed for the version header to be
                # written: LibSVM cannot parse comments, so the dumper no
                # longer emits them by default.
                dump_svmlight_file(X.astype(dtype), y, buf, comment="test",
                                   zero_based=zero_based)
                buf.seek(0)
                version_line = to_text(buf.readline())
                assert_in("scikit-learn %s" % sklearn.__version__,
                          version_line)
                base_line = to_text(buf.readline())
                assert_in(["one", "zero"][zero_based] + "-based", base_line)

                X2, y2 = load_svmlight_file(buf, dtype=dtype,
                                            zero_based=zero_based)
                assert_equal(X2.dtype, dtype)
                assert_array_equal(X2.sorted_indices().indices, X2.indices)
                # allow a rounding error at the last decimal place
                decimal = 4 if dtype == np.float32 else 15
                assert_array_almost_equal(Xd.astype(dtype), X2.toarray(),
                                          decimal)
                assert_array_equal(y, y2)
开发者ID:albertotb,项目名称:scikit-learn,代码行数:52,代码来源:test_svmlight_format.py
示例10: test_friedman_mse_in_graphviz
def test_friedman_mse_in_graphviz():
    """Graphviz exports of ``friedman_mse`` trees must name that criterion.

    Covers a standalone DecisionTreeRegressor and the first-column trees of a
    GradientBoostingClassifier; all exports are written to one buffer.
    """
    clf = DecisionTreeRegressor(criterion="friedman_mse", random_state=0)
    clf.fit(X, y)
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data)

    clf = GradientBoostingClassifier(n_estimators=2, random_state=0)
    clf.fit(X, y)
    for estimator in clf.estimators_:
        export_graphviz(estimator[0], out_file=dot_data)

    # Bracketed spans mentioning "samples" are taken to be the node labels.
    node_labels = [match.group() for match in
                   finditer(r"\[.*?samples.*?\]", dot_data.getvalue())]
    # Guard against a vacuous pass if the pattern stops matching anything.
    assert node_labels, "no node labels found in the graphviz output"
    for label in node_labels:
        assert_in("friedman_mse", label)
开发者ID:daniel-perry,项目名称:scikit-learn,代码行数:13,代码来源:test_export.py
示例11: test_countvectorizer_empty_vocabulary
def test_countvectorizer_empty_vocabulary():
    """CountVectorizer must raise ValueError mentioning an empty vocabulary.

    Two cases are covered: an explicitly empty ``vocabulary`` and fitting on
    documents made only of English stop words.
    """
    try:
        CountVectorizer(vocabulary=[])
    except ValueError as e:
        assert_in("empty vocabulary", str(e).lower())
    else:
        # An explicit raise (unlike ``assert False``) is not stripped by -O.
        raise AssertionError("we shouldn't get here")

    try:
        v = CountVectorizer(max_df=1.0, stop_words="english")
        # fit on stopwords only
        v.fit(["to be or not to be", "and me too", "and so do you"])
    except ValueError as e:
        assert_in("empty vocabulary", str(e).lower())
    else:
        raise AssertionError("we shouldn't get here")
开发者ID:BloodD,项目名称:scikit-learn,代码行数:14,代码来源:test_text.py
示例12: test_download
def test_download():
    """fetch_mldata must download a (mocked) data set, cache it and expose
    its fields; an unknown name must raise the mock's HTTPError."""
    # Swap the module's urllib2 for a mock serving one fake data set named
    # 'mock'; the real module is restored in ``finally``.
    original_urllib2 = datasets.mldata.urllib2
    fake_sets = {'mock': {'label': sp.ones((150,)),
                          'data': sp.ones((150, 4))}}
    datasets.mldata.urllib2 = mock_urllib2(fake_sets)
    try:
        dataset = fetch_mldata('mock', data_home=tmpdir)
        # NOTE(review): called as assert_in(obj, in_=[...]), which is not the
        # unittest-style assertIn(member, container) signature; presumably a
        # helper defined in this test module — confirm.
        assert_in(dataset, in_=['COL_NAMES', 'DESCR', 'target', 'data'])
        assert_equal(dataset.target.shape, (150,))
        assert_equal(dataset.data.shape, (150, 4))
        # datasets.mldata.urllib2 is still the mock here, so this is the
        # mock's HTTPError.
        assert_raises(datasets.mldata.urllib2.HTTPError,
                      fetch_mldata, 'not_existing_name')
    finally:
        datasets.mldata.urllib2 = original_urllib2
开发者ID:QuarkSpark,项目名称:scikit-learn,代码行数:18,代码来源:test_mldata.py
示例13: test_n_iter_without_progress
def test_n_iter_without_progress():
    """The verbose log must report the configured (negative)
    ``n_iter_without_progress`` value when TSNE stops."""
    # Use a dummy negative n_iter_without_progress and check output on stdout
    random_state = check_random_state(0)
    X = random_state.randn(100, 2)
    tsne = TSNE(n_iter_without_progress=-1, verbose=2,
                random_state=1, method='exact')

    # Capture stdout; it is restored even if fit_transform raises.
    saved_stdout, sys.stdout = sys.stdout, StringIO()
    try:
        tsne.fit_transform(X)
    finally:
        captured = sys.stdout.getvalue()
        sys.stdout.close()
        sys.stdout = saved_stdout

    expected = ("did not make any progress during the "
                "last -1 episodes. Finished.")
    assert_in(expected, captured)
开发者ID:AlexandreAbraham,项目名称:scikit-learn,代码行数:19,代码来源:test_t_sne.py
示例14: test_n_iter_without_progress
def test_n_iter_without_progress():
    """The verbose log must report the configured ``n_iter_without_progress``
    value (2) when TSNE stops."""
    # Make sure that the parameter n_iter_without_progress is used correctly
    random_state = check_random_state(0)
    X = random_state.randn(100, 2)
    tsne = TSNE(n_iter_without_progress=2, verbose=2,
                random_state=0, method='exact')

    # Capture stdout; it is restored even if fit_transform raises.
    saved_stdout, sys.stdout = sys.stdout, StringIO()
    try:
        tsne.fit_transform(X)
    finally:
        captured = sys.stdout.getvalue()
        sys.stdout.close()
        sys.stdout = saved_stdout

    expected = ("did not make any progress during the "
                "last 2 episodes. Finished.")
    assert_in(expected, captured)
开发者ID:ManrajGrover,项目名称:scikit-learn,代码行数:19,代码来源:test_t_sne.py
示例15: test_fetch_one_column
def test_fetch_one_column():
    """A single-variable mldata set loads into ``data`` and can be fetched
    untransposed with ``transpose_data=False``."""
    original_urllib2 = datasets.mldata.urllib2
    try:
        name = 'onecol'
        # create fake data set in cache
        raw = sp.arange(6).reshape(2, 3)
        datasets.mldata.urllib2 = mock_urllib2({name: {'x': raw}})

        bunch = fetch_mldata(name, data_home=tmpdir)
        # NOTE(review): assert_in(obj, in_=..., out_=...) is presumably a
        # helper defined in this test module (out_ listing names that must
        # be absent) — confirm.
        assert_in(bunch, in_=['COL_NAMES', 'DESCR', 'data'], out_=['target'])
        assert_equal(bunch.data.shape, (2, 3))
        assert_array_equal(bunch.data, raw)

        # transposing the data array
        bunch = fetch_mldata(name, transpose_data=False, data_home=tmpdir)
        assert_equal(bunch.data.shape, (3, 2))
    finally:
        datasets.mldata.urllib2 = original_urllib2
开发者ID:QuarkSpark,项目名称:scikit-learn,代码行数:19,代码来源:test_mldata.py
示例16: test_unseen_or_no_features
def test_unseen_or_no_features():
    """Unknown keys and empty dicts map to an all-zero row, for both sparse
    and dense output; an empty sample list should raise ValueError."""
    D = [{"camelot": 0, "spamalot": 1}]
    zero_row = np.zeros((1, 2))
    for sparse in [True, False]:
        vec = DictVectorizer(sparse=sparse).fit(D)
        # One sample with an unseen feature, then one with no features.
        for sample in ({"push the pram a lot": 2}, {}):
            out = vec.transform(sample)
            if sparse:
                out = out.toarray()
            assert_array_equal(out, zero_row)

        # NOTE(review): there is no ``else`` branch, so if transform([]) does
        # not raise (possibly the dense path in some versions) this check
        # passes silently — confirm whether that is intended.
        try:
            vec.transform([])
        except ValueError as e:
            assert_in("empty", str(e))
开发者ID:0664j35t3r,项目名称:scikit-learn,代码行数:19,代码来源:test_dict_vectorizer.py
示例17: test_valid_brute_metric_for_auto_algorithm
def test_valid_brute_metric_for_auto_algorithm():
    """Every metric listed as valid for the brute backend must work with
    ``NearestNeighbors(algorithm='auto')`` on dense and CSR input.

    Metrics that need extra parameters are exercised at the end with explicit
    ``metric_params``. The order of ``rng`` calls is kept as-is.
    """
    X = rng.rand(12, 12)
    Xcsr = csr_matrix(X)

    # check that there is a metric that is valid for brute
    # but not ball_tree (so we actually test something)
    assert_in("cosine", VALID_METRICS['brute'])
    assert_false("cosine" in VALID_METRICS['ball_tree'])

    # Metrics that require an additional parameter; they are skipped in the
    # loops below and tested with explicit metric_params at the end.
    require_params = ['mahalanobis', 'wminkowski', 'seuclidean']
    for metric in VALID_METRICS['brute']:
        if metric != 'precomputed' and metric not in require_params:
            nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
                                            metric=metric).fit(X)
            nn.kneighbors(X)
        elif metric == 'precomputed':
            X_precomputed = rng.random_sample((10, 4))
            Y_precomputed = rng.random_sample((3, 4))
            # (10, 10) train-train and (3, 10) query-train distance matrices.
            DXX = metrics.pairwise_distances(X_precomputed, metric='euclidean')
            DYX = metrics.pairwise_distances(Y_precomputed, X_precomputed,
                                             metric='euclidean')
            # metric='precomputed' is required: with the default (minkowski)
            # metric the distance matrices would be treated as raw features
            # and the precomputed path would never be exercised.
            nb_p = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
                                              metric='precomputed')
            nb_p.fit(DXX)
            nb_p.kneighbors(DYX)

    for metric in VALID_METRICS_SPARSE['brute']:
        if metric != 'precomputed' and metric not in require_params:
            nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
                                            metric=metric).fit(Xcsr)
            nn.kneighbors(Xcsr)

    # Metric with parameter
    VI = np.dot(X, X.T)
    list_metrics = [('seuclidean', dict(V=rng.rand(12))),
                    ('wminkowski', dict(w=rng.rand(12))),
                    ('mahalanobis', dict(VI=VI))]
    for metric, params in list_metrics:
        nn = neighbors.NearestNeighbors(n_neighbors=3, algorithm='auto',
                                        metric=metric,
                                        metric_params=params).fit(X)
        nn.kneighbors(X)
开发者ID:BasilBeirouti,项目名称:scikit-learn,代码行数:42,代码来源:test_neighbors.py
示例18: test_download
def test_download(tmpdata):
    """fetch_mldata must download a (mocked) data set into ``tmpdata`` and
    expose its fields; an unknown name must raise ``mldata.HTTPError``."""
    # Swap urlopen for a mock serving one fake data set named 'mock'; the
    # real function is restored in ``finally``.
    saved_urlopen = datasets.mldata.urlopen
    fake_sets = {
        'mock': {
            'label': sp.ones((150,)),
            'data': sp.ones((150, 4)),
        },
    }
    datasets.mldata.urlopen = mock_mldata_urlopen(fake_sets)
    try:
        bunch = fetch_mldata('mock', data_home=tmpdata)
        for key in ("COL_NAMES", "DESCR", "target", "data"):
            assert_in(key, bunch)
        assert_equal(bunch.target.shape, (150,))
        assert_equal(bunch.data.shape, (150, 4))
        assert_raises(datasets.mldata.HTTPError,
                      fetch_mldata, 'not_existing_name')
    finally:
        datasets.mldata.urlopen = saved_urlopen
开发者ID:AlexisMignon,项目名称:scikit-learn,代码行数:21,代码来源:test_mldata.py
示例19: test_fetch_one_column
def test_fetch_one_column():
    """A single-variable mldata set loads into ``data`` with no ``target``
    and can be fetched untransposed with ``transpose_data=False``."""
    saved_urlopen = datasets.mldata.urlopen
    try:
        name = 'onecol'
        # create fake data set in cache
        raw = sp.arange(6).reshape(2, 3)
        datasets.mldata.urlopen = mock_mldata_urlopen({name: {'x': raw}})

        bunch = fetch_mldata(name, data_home=tmpdir)
        for key in ("COL_NAMES", "DESCR", "data"):
            assert_in(key, bunch)
        assert_not_in("target", bunch)
        assert_equal(bunch.data.shape, (2, 3))
        assert_array_equal(bunch.data, raw)

        # transposing the data array
        bunch = fetch_mldata(name, transpose_data=False, data_home=tmpdir)
        assert_equal(bunch.data.shape, (3, 2))
    finally:
        datasets.mldata.urlopen = saved_urlopen
开发者ID:Ranumao,项目名称:scikit-learn,代码行数:21,代码来源:test_mldata.py
示例20: test_n_iter_without_progress
def test_n_iter_without_progress():
    """For both TSNE methods the verbose log must report the configured
    (negative) ``n_iter_without_progress`` value when optimisation stops."""
    # Use a dummy negative n_iter_without_progress and check output on stdout
    random_state = check_random_state(0)
    X = random_state.randn(100, 10)
    expected = ("did not make any progress during the "
                "last -1 episodes. Finished.")
    for method in ["barnes_hut", "exact"]:
        tsne = TSNE(n_iter_without_progress=-1, verbose=2, learning_rate=1e8,
                    random_state=0, method=method, n_iter=351, init="random")
        # NOTE(review): private knobs — presumably "check progress every
        # iteration" and "skip the exploration phase"; confirm against TSNE.
        tsne._N_ITER_CHECK = 1
        tsne._EXPLORATION_N_ITER = 0

        # Capture stdout; it is restored even if fit_transform raises.
        saved_stdout, sys.stdout = sys.stdout, StringIO()
        try:
            tsne.fit_transform(X)
        finally:
            captured = sys.stdout.getvalue()
            sys.stdout.close()
            sys.stdout = saved_stdout

        # The output needs to contain the value of n_iter_without_progress
        assert_in(expected, captured)
开发者ID:BasilBeirouti,项目名称:scikit-learn,代码行数:22,代码来源:test_t_sne.py
注：本文中的sklearn.utils.testing.assert_in函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论