• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    迪恩网络公众号

Python tvm.get_global_func函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tvm.get_global_func函数的典型用法代码示例。如果您正苦于以下问题:Python get_global_func函数的具体用法?Python get_global_func怎么用?Python get_global_func使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了get_global_func函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: verify

    def verify(target="llvm",
               algorithm=nnpack.ConvolutionAlgorithm.AUTO,
               with_bias=True):
        """Build and check an NNPACK convolution that uses a pre-transformed kernel.

        Returns early (skip) when ``target`` is not enabled, when the
        ``tvm.contrib.nnpack.fully_connected_inference`` global function is
        not registered, or when ``nnpack.is_available()`` is false.

        ``data``, ``kernel``, ``bias``, the ``*shape`` values, ``PAD``,
        ``STRIDE``, ``BATCH``, ``IC``, ``IH``, ``IW`` and ``np_conv`` are
        defined outside this function.

        Parameters
        ----------
        target : str
            Build target handed to ``tvm.build``.
        algorithm : nnpack.ConvolutionAlgorithm
            Algorithm passed to both the weight transform and the convolution.
        with_bias : bool
            When False the convolution receives ``None`` as its bias and the
            bias buffer passed to the built function is all zeros.
        """
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        extern_fn = tvm.get_global_func(
            "tvm.contrib.nnpack.fully_connected_inference", True)
        if not extern_fn:
            print("skip because extern function is not available")
            return
        if not nnpack.is_available():
            return

        ctx = tvm.cpu(0)
        conv_bias = bias if with_bias else None
        packed_kernel = nnpack.convolution_inference_weight_transform(
            kernel, algorithm=algorithm)
        output = nnpack.convolution_inference_without_weight_transform(
            data, packed_kernel, conv_bias,
            [PAD] * 4, [STRIDE] * 2,
            algorithm=algorithm)

        sched = tvm.create_schedule(output.op)
        func = tvm.build(sched, [data, kernel, bias, output], target)

        # Host-side inputs; the bias buffer is always passed, zero-filled
        # when the bias is disabled.
        data_np = np.random.uniform(size=dshape).astype(data.dtype)
        kernel_np = np.random.uniform(size=kshape).astype(kernel.dtype)
        if with_bias:
            bias_np = np.random.uniform(size=bshape).astype(bias.dtype)
        else:
            bias_np = np.zeros(bshape, dtype=bias.dtype)

        data_nd = tvm.nd.array(data_np, ctx)
        kernel_nd = tvm.nd.array(kernel_np, ctx)
        bias_nd = tvm.nd.array(bias_np, ctx)
        out_nd = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
        func(data_nd, kernel_nd, bias_nd, out_nd)

        # Reference result: np_conv output plus the bias reshaped to
        # (1, bshape[0], 1, 1) so that it broadcasts along axis 1.
        conv_ref = np_conv(
            np.reshape(data_np, (BATCH, IC, IH, IW)), kernel_np, PAD, STRIDE)
        expected = conv_ref + bias_np.reshape(1, bshape[0], 1, 1)
        tvm.testing.assert_allclose(
            out_nd.asnumpy(), expected.reshape(BATCH, IC, IH, IW), rtol=1e-5)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:35,代码来源:test_nnpack.py


示例2: check

 def check(factor):
     """Tensorize the inner split of ``z`` and validate the inferred regions.

     Splits the first axis of ``z`` by ``factor``, tensorizes the inner part
     with ``intrin_vadd(factor)``, then checks the output/input domains
     reported by ``test.op.InferTensorizeRegion`` and the body reported by
     ``test.op.MatchTensorizeBody``. ``x``, ``y``, ``z`` and ``intrin_vadd``
     are defined outside this function.
     """
     sched = tvm.create_schedule(z.op)
     axis0 = z.op.axis[0]
     outer, inner = sched[z].split(axis0, factor=factor)
     vadd = intrin_vadd(factor)
     sched[z].tensorize(inner, vadd)
     sched = sched.normalize()
     dom_map = tvm.schedule.InferBound(sched)

     infer_region = tvm.get_global_func("test.op.InferTensorizeRegion")
     out_dom, in_dom = infer_region(sched[z], dom_map)
     out_range = out_dom[axis0]
     assert tvm.ir_pass.Equal(out_range.extent, factor)
     assert tvm.ir_pass.Equal(out_range.min, outer * factor)
     assert tvm.ir_pass.Equal(in_dom.items()[0][1][0].extent, factor)

     match_body = tvm.get_global_func("test.op.MatchTensorizeBody")
     body = match_body(sched[z], out_dom, in_dom, vadd)
     matched = tvm.ir_pass.CanonicalSimplify(body[0])
     declared = tvm.ir_pass.CanonicalSimplify(vadd.op.body[0])
     assert tvm.ir_pass.Equal(matched, declared)

     # The remaining calls are made only to confirm they complete.
     tvm.schedule.ScheduleOps(sched, dom_map)
     tvm.lower(sched, [x, y, z])
开发者ID:bddppq,项目名称:tvm,代码行数:18,代码来源:test_schedule_tensorize.py


示例3: test_conv2d

def test_conv2d():
    """Compare MIOpen's conv2d on ROCm against topi's ``conv2d_nchw``.

    The test prints a message and returns without failing when the "rocm"
    target is not enabled or when ``tvm.contrib.miopen.conv2d.setup`` is not
    registered. Otherwise both implementations are built, run on the same
    random inputs, and their outputs compared with ``atol=1e-3``.
    """
    in_channel, out_channel = 3, 64
    filter_h = filter_w = 3
    pad_h = pad_w = 1
    stride_h = stride_w = 1
    dilation_h = dilation_w = 1

    xshape = [1, in_channel, 128, 128]
    if not tvm.module.enabled("rocm"):
        print("skip because rocm is not enabled...")
        return
    setup_fn = tvm.get_global_func("tvm.contrib.miopen.conv2d.setup", True)
    if not setup_fn:
        print("skip because miopen is not enabled...")
        return
    wshape = (out_channel, in_channel, filter_h, filter_w)

    X = tvm.placeholder(xshape, name='X')
    W = tvm.placeholder(wshape, name='W')
    Y = miopen.conv2d_forward(
        X, W,
        stride_h, stride_w,
        pad_h, pad_w,
        dilation_h, dilation_w,
        conv_mode=0)

    yshape = [dim.value for dim in Y.shape]
    # Imported only after the skip guards, so a missing topi does not turn a
    # skip into an ImportError.
    import topi
    with tvm.target.create("rocm -libs=miopen"):
        sched = topi.generic.schedule_extern(Y)

    ctx = tvm.rocm(0)

    def random_nd(shape):
        """Return a float32 device array filled with uniform values in [-1, 1)."""
        host = np.random.uniform(-1, 1, shape).astype(np.float32)
        return tvm.nd.array(host, ctx)

    # MIOpen path.
    miopen_fn = tvm.build(sched, [X, W, Y], "rocm", target_host="llvm", name="conv2d")
    x_nd = random_nd(xshape)
    w_nd = random_nd(wshape)
    y_nd = random_nd(yshape)
    miopen_fn(x_nd, w_nd, y_nd)

    # Reference path built from topi.
    Y_ref = topi.nn.conv2d_nchw(X, W, (stride_h, stride_w), (pad_h, pad_w))
    with tvm.target.rocm():
        sched_ref = topi.generic.schedule_conv2d_nchw([Y_ref])
    ref_fn = tvm.build(sched_ref, [X, W, Y_ref], "rocm")
    y_ref_nd = random_nd(yshape)
    ref_fn(x_nd, w_nd, y_ref_nd)

    print("Max abs diff:", np.max(np.abs(y_nd.asnumpy() - y_ref_nd.asnumpy())))
    tvm.testing.assert_allclose(y_nd.asnumpy(), y_ref_nd.asnumpy(), atol=1e-3)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:56,代码来源:test_miopen.py


示例4: stats

def stats():
    """Return the current profiler statistics.

    Calls the ``vta.simulator.profiler_status`` global function and decodes
    the value it returns as JSON. Nothing is cleared or reset by this call.

    Returns
    -------
    stats : dict
        Current profiler statistics
    """
    status_json = tvm.get_global_func("vta.simulator.profiler_status")()
    return json.loads(status_json)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:10,代码来源:simulator.py


示例5: verify

 def verify(A, B, C, target="llvm"):
     """Build schedule ``s1`` for Metal and run it once on random inputs.

     Prints a message and returns when ``tvm.contrib.mps.conv2d`` is not
     registered. No numerical comparison is made; the function only checks
     that building and running complete. ``target`` is accepted but the
     build always uses "metal". ``s1``, ``n``, ``h``, ``w``, ``ci``, ``co``,
     ``kh``, ``kw`` and ``stride`` are defined outside this function.
     """
     mps_conv = tvm.get_global_func("tvm.contrib.mps.conv2d", True)
     if not mps_conv:
         print("skip because extern function is not available")
         return

     ctx = tvm.metal(0)
     func = tvm.build(s1, [A, B, C], "metal")

     a_shape = (n, h, w, ci)
     b_shape = (co, kh, kw, ci)
     c_shape = (n, h // stride, w // stride, co)
     a_nd = tvm.nd.array(np.random.uniform(size=a_shape).astype(A.dtype), ctx)
     b_nd = tvm.nd.array(np.random.uniform(size=b_shape).astype(B.dtype), ctx)
     c_nd = tvm.nd.array(np.zeros(c_shape, dtype=C.dtype), ctx)
     func(a_nd, b_nd, c_nd)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:10,代码来源:test_mps.py


示例6: test_get_global

def test_get_global():
    """Register a Python function globally and call it back through TVM."""
    expected_args = (10, 10.0, "hello")

    # The decorator is used without an explicit name, so the function's own
    # name is what gets looked up below.
    @tvm.register_func
    def my_packed_func(*args):
        assert tuple(args) == expected_args
        return 10

    # Fetch it back from the global function table and invoke it.
    packed = tvm.get_global_func("my_packed_func")
    assert isinstance(packed, tvm.Function)
    result = packed(*expected_args)
    assert result == 10
开发者ID:bddppq,项目名称:tvm,代码行数:12,代码来源:test_runtime_packed_func.py


示例7: test_conv2d

def test_conv2d():
    """Build and run cuDNN's conv2d forward once on random inputs.

    The test prints a message and returns without failing when the "cuda"
    target is not enabled or when
    ``tvm.contrib.cudnn.conv2d.output_shape`` is not registered. No
    numerical comparison is made; the test only checks that the build and
    the call complete.
    """
    in_channel, out_channel = 3, 32
    filter_h = filter_w = 3
    pad_h = pad_w = 1
    stride_h = stride_w = 1
    dilation_h = dilation_w = 1

    xshape = [4, 3, 32, 32]
    if not tvm.module.enabled("cuda"):
        print("skip because cuda is not enabled...")
        return
    shape_fn = tvm.get_global_func("tvm.contrib.cudnn.conv2d.output_shape", True)
    if not shape_fn:
        print("skip because cudnn is not enabled...")
        return
    wshape = cudnn.conv2d_w_shape(in_channel, out_channel, filter_h, filter_w)

    X = tvm.placeholder(xshape, name='X')
    W = tvm.placeholder(wshape, name='W')
    Y = cudnn.conv2d_forward(
        X, W,
        stride_h, stride_w,
        pad_h, pad_w,
        dilation_h, dilation_w,
        conv_mode=1,
        tensor_format=0,
        algo=1)
    yshape = [dim.value for dim in Y.shape]
    sched = tvm.create_schedule(Y.op)

    ctx = tvm.gpu(0)
    func = tvm.build(sched, [X, W, Y], "cuda", target_host="llvm", name="conv2d")

    def random_nd(shape):
        """Return a float32 device array filled with uniform values in [-1, 1)."""
        host = np.random.uniform(-1, 1, shape).astype(np.float32)
        return tvm.nd.array(host, ctx)

    x_nd = random_nd(xshape)
    w_nd = random_nd(wshape)
    y_nd = random_nd(yshape)
    func(x_nd, w_nd, y_nd)
开发者ID:bddppq,项目名称:tvm,代码行数:52,代码来源:test_cudnn.py


示例8: check_rfactor_no_reset_multi_reduction

 def check_rfactor_no_reset_multi_reduction(factor, rfactor):
     """Tensorize a GEMV with a twice-split reduction axis and validate it.

     The ``y`` axis of ``C`` is split by ``factor``; the reduction axis is
     split by ``rfactor`` and its outer part split again by 2. After
     tensorizing with ``intrin_gemv_no_reset(factor, rfactor)`` the inferred
     output domain and the matched body are checked. ``A``, ``B``, ``C`` and
     ``intrin_gemv_no_reset`` are defined outside this function.
     """
     sched = tvm.create_schedule(C.op)
     x, y = C.op.axis
     red_axis = C.op.reduce_axis[0]
     y_outer, y_inner = sched[C].split(y, factor=factor)
     r_outer, r_inner = sched[C].split(red_axis, factor=rfactor)
     r_outer_o, r_outer_i = sched[C].split(r_outer, factor=2)
     sched[C].reorder(y_outer, r_outer_o, r_outer_i, y_inner, r_inner)

     gemv = intrin_gemv_no_reset(factor, rfactor)
     sched[C].tensorize(y_inner, gemv)
     sched = sched.normalize()
     dom_map = tvm.schedule.InferBound(sched)

     infer_region = tvm.get_global_func("test.op.InferTensorizeRegion")
     out_dom, in_dom = infer_region(sched[C], dom_map)
     assert tvm.ir_pass.Equal(out_dom[x].extent, 1)
     y_range = out_dom[y]
     assert tvm.ir_pass.Equal(y_range.extent, factor)
     assert tvm.ir_pass.Equal(y_range.min, y_outer * factor)

     match_body = tvm.get_global_func("test.op.MatchTensorizeBody")
     body = match_body(sched[C], out_dom, in_dom, gemv)
     matched = tvm.ir_pass.CanonicalSimplify(body[0])
     declared = tvm.ir_pass.CanonicalSimplify(gemv.op.body[0])
     assert tvm.ir_pass.Equal(matched, declared)

     # The remaining calls are made only to confirm they complete.
     tvm.schedule.ScheduleOps(sched, dom_map)
     tvm.lower(sched, [A, B, C])
开发者ID:bddppq,项目名称:tvm,代码行数:23,代码来源:test_schedule_tensorize.py


示例9: verify

 def verify(target="llvm"):
     """Build schedule ``s`` and check the mean and std of the filled array.

     Prints a message and returns when ``target`` is not enabled or when
     ``tvm.contrib.random.normal`` is not registered. Otherwise the built
     function fills an ``(m, n)`` array, whose sample mean must be within
     1e-2 of 3 and whose sample standard deviation within 1e-2 of 4.
     ``s``, ``A``, ``m`` and ``n`` are defined outside this function.
     """
     if not tvm.module.enabled(target):
         print("skip because %s is not enabled..." % target)
         return
     normal_fn = tvm.get_global_func("tvm.contrib.random.normal", True)
     if not normal_fn:
         print("skip because extern function is not available")
         return

     ctx = tvm.cpu(0)
     func = tvm.build(s, [A], target)
     out_nd = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), ctx)
     func(out_nd)

     samples = out_nd.asnumpy()
     mean_error = abs(np.mean(samples) - 3)
     assert mean_error < 1e-2
     std_error = abs(np.std(samples) - 4)
     assert std_error < 1e-2
开发者ID:LANHUIYING,项目名称:tvm,代码行数:14,代码来源:test_random.py


示例10: verify

 def verify(target="rocm"):
     """Build schedule ``s`` and compare its result with ``np.dot``.

     Prints a message and returns when ``target`` is not enabled or when
     ``tvm.contrib.rocblas.matmul`` is not registered. ``s``, ``A``, ``B``,
     ``C``, ``n``, ``l`` and ``m`` are defined outside this function.
     """
     if not tvm.module.enabled(target):
         print("skip because %s is not enabled..." % target)
         return
     matmul_fn = tvm.get_global_func("tvm.contrib.rocblas.matmul", True)
     if not matmul_fn:
         print("skip because extern function is not available")
         return

     ctx = tvm.rocm(0)
     func = tvm.build(s, [A, B, C], target)

     lhs = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
     rhs = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
     out = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
     func(lhs, rhs, out)

     expected = np.dot(lhs.asnumpy(), rhs.asnumpy())
     tvm.testing.assert_allclose(out.asnumpy(), expected, rtol=1e-5)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:15,代码来源:test_rocblas.py


示例11: verify

 def verify(target="llvm"):
     """Build schedule ``s`` and compare its result with ``np.dot`` plus a bias.

     Prints a message and returns when ``target`` is not enabled or when
     ``tvm.contrib.cblas.matmul`` is not registered. The built function is
     called with a scalar bias of 10.0, and the output must match
     ``np.dot(a, b) + 10.0`` with ``rtol=1e-5``. ``s``, ``A``, ``B``, ``D``,
     ``bias``, ``n``, ``l`` and ``m`` are defined outside this function.
     """
     if not tvm.module.enabled(target):
         print("skip because %s is not enabled..." % target)
         return
     matmul_fn = tvm.get_global_func("tvm.contrib.cblas.matmul", True)
     if not matmul_fn:
         print("skip because extern function is not avalable")
         return

     ctx = tvm.cpu(0)
     func = tvm.build(s, [A, B, D, bias], target)

     lhs = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx)
     rhs = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), ctx)
     out = tvm.nd.array(np.zeros((n, m), dtype=D.dtype), ctx)
     bias_value = 10.0
     func(lhs, rhs, out, bias_value)

     expected = np.dot(lhs.asnumpy(), rhs.asnumpy()) + bias_value
     np.testing.assert_allclose(out.asnumpy(), expected, rtol=1e-5)
开发者ID:gwli,项目名称:tvm,代码行数:16,代码来源:test_cblas.py


示例12: test_get_callback_with_node

def test_get_callback_with_node():
    """Pass a node and a callback through a globally registered function."""
    node = tvm.convert(10)

    def passthrough(arg):
        # The callback must receive a different handle than the original node.
        assert arg.handle != node.handle
        return arg

    callback = tvm.convert(passthrough)

    # The decorator is used without an explicit name, so the function's own
    # name is what gets looked up below.
    @tvm.register_func
    def my_callback_with_node(arg, func):
        assert arg == node
        return func(arg)

    # Fetch it back from the global function table and invoke it.
    registered = tvm.get_global_func("my_callback_with_node")
    assert isinstance(registered, tvm.Function)
    result = registered(node, callback)
    assert result.value == 10
开发者ID:bddppq,项目名称:tvm,代码行数:18,代码来源:test_runtime_packed_func.py


示例13: verify

    def verify(target="llvm"):
        """Build schedule ``s`` and compare its output with ``np_conv``.

        Prints a message and returns when ``target`` is not enabled or when
        ``tvm.contrib.nnpack.fully_connected_inference`` is not registered.
        The bias buffer is all zeros, so the expected result is
        ``np_conv(na, nb, PAD)`` alone. ``s``, ``data``, ``kernel``, ``bias``,
        ``output``, the ``*shape`` values, ``PAD`` and ``np_conv`` are defined
        outside this function.
        """
        if not tvm.module.enabled(target):
            print("skip because %s is not enabled..." % target)
            return
        extern_fn = tvm.get_global_func(
            "tvm.contrib.nnpack.fully_connected_inference", True)
        if not extern_fn:
            print("skip because extern function is not avalable")
            return

        ctx = tvm.cpu(0)
        func = tvm.build(s, [data, kernel, bias, output], target)

        data_np = np.random.uniform(size=dshape).astype(data.dtype)
        kernel_np = np.random.uniform(size=kshape).astype(kernel.dtype)
        bias_np = np.zeros(bshape, dtype=bias.dtype)

        data_nd = tvm.nd.array(data_np, ctx)
        kernel_nd = tvm.nd.array(kernel_np, ctx)
        bias_nd = tvm.nd.array(bias_np, ctx)
        out_nd = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
        func(data_nd, kernel_nd, bias_nd, out_nd)

        expected = np_conv(data_np, kernel_np, PAD)
        np.testing.assert_allclose(out_nd.asnumpy(), expected, rtol=1e-5)
开发者ID:gwli,项目名称:tvm,代码行数:21,代码来源:test_nnpack.py


示例14: test_env_func

def test_env_func():
    """Round-trip an EnvFunc through JSON, alone and as a node attribute."""
    @tvm.register_func("test.env_func")
    def add_one(value):
        return value + 1

    # Looked up only to confirm the registration is visible.
    tvm.get_global_func("test.env_func")

    env_fn = tvm.get_env_func("test.env_func")
    assert env_fn.name == "test.env_func"
    serialized = tvm.save_json([env_fn])
    restored = tvm.load_json(serialized)[0]
    assert restored.name == env_fn.name
    assert restored(1) == 2
    assert restored.func(1) == 2

    # Embed the restored EnvFunc in an attrs node and round-trip that too.
    attrs = tvm.make.node("attrs.TestAttrs", name="xx", padding=(3,4), func=restored)
    assert attrs.name == "xx"
    assert attrs.padding[0].value == 3
    assert attrs.padding[1].value == 4
    assert attrs.axis == 10
    attrs = tvm.load_json(tvm.save_json(attrs))
    assert isinstance(attrs.func, tvm.container.EnvFunc)
    assert attrs.func(10) == 11
开发者ID:bddppq,项目名称:tvm,代码行数:22,代码来源:test_lang_reflection.py


示例15: save_tensors

def save_tensors(params):
    """Serialize a parameter dictionary into binary bytes.

    The returned bytes can be loaded by the GraphModule through its
    "load_params" API.

    Parameters
    ----------
    params : dict of str to NDArray
        The parameter dictionary.

    Returns
    -------
    param_bytes: bytearray
        Serialized parameters.
    """
    serialize = tvm.get_global_func("_save_param_dict")

    # The packed function takes a flat name, array, name, array, ... list.
    flat_args = []
    for name, value in params.items():
        flat_args.extend((name, tvm.nd.array(value)))
    return serialize(*flat_args)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:23,代码来源:debug_result.py


示例16: save_param_dict

# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Helper utility to save parameter dict"""
import tvm

# Packed functions resolved once, at module import time, from TVM's global
# function table under the "nnvm.compiler." prefix.
_save_param_dict = tvm.get_global_func("nnvm.compiler._save_param_dict")
_load_param_dict = tvm.get_global_func("nnvm.compiler._load_param_dict")

def save_param_dict(params):
    """Save parameter dictionary to binary bytes.

    The result binary bytes can be loaded by the
    GraphModule with API "load_params".

    Parameters
    ----------
    params : dict of str to NDArray
        The parameter dictionary.

    Returns
    -------
开发者ID:bddppq,项目名称:tvm,代码行数:31,代码来源:param_dict.py


示例17: GraphKey

#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Compiler engine interface to internal engine

You can get the engine singleton at ``nnvm.compiler.engine``
"""
import tvm

# Packed functions of the internal compile engine, resolved once at module
# import time from TVM's global function table.
_list_cache_items = tvm.get_global_func("nnvm.compiler.ListCacheItems")
_clear_cache = tvm.get_global_func("nnvm.compiler.ClearCache")
_get_cache_item = tvm.get_global_func("nnvm.compiler.GetCacheItem")
_set_cache_item = tvm.get_global_func("nnvm.compiler.SetCacheItem")
_graph_key_get_graph = tvm.get_global_func("nnvm.compiler.GraphKeyGetGraph")
_make_graph_key = tvm.get_global_func("nnvm.compiler.MakeGraphKey")

@tvm.register_node
class GraphKey(tvm.node.NodeBase):
    """Key of a graph compilation context"""

    @property
    def graph(self):
        """Graph obtained by passing this key to ``nnvm.compiler.GraphKeyGetGraph``."""
        graph_of_key = _graph_key_get_graph(self)
        return graph_of_key


@tvm.register_node
开发者ID:bddppq,项目名称:tvm,代码行数:31,代码来源:compile_engine.py


示例18: check_graph_equal

    out_dtype: list of tuple
         Dtype of outputs
    """
    graph = graph_attr.set_dtype_inputs(graph, dtype)
    graph = graph.apply("InferType")
    dtype = graph.json_attr("dtype")
    index = graph.index
    input_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
                   for x in index.input_names]
    output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
                    for x in index.output_entries]
    return input_dtype, output_dtype


# Packed function resolved once, at module import time, from TVM's global
# function table.
_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")

def check_graph_equal(grapha, graphb, compare_variable_attrs=False):
    """Check if two graphs have equal structure.

    Parameters
    ----------
    grapha : Graph
        The first graph

    graphb : Graph
        The second graph

    compare_variable_attrs : bool, optional
        Whether we want to compare attributes(names) on variables.
        Usually it is safe to skip it unless we want input name
开发者ID:bddppq,项目名称:tvm,代码行数:30,代码来源:graph_util.py


示例19: AttrDict

# pylint: disable=invalid-name
"""Attr dictionary object used by schedule functions"""
import tvm

# Packed functions backing AttrDict, resolved once at module import time
# from TVM's global function table.
_dict_get = tvm.get_global_func("nnvm.compiler._dict_get")
_dict_size = tvm.get_global_func("nnvm.compiler._dict_size")
_dict_keys = tvm.get_global_func("nnvm.compiler._dict_keys")

class AttrDict(object):
    """Attribute dictionary in nnvm.

    Used by python registration of compute and schedule function.
    AttrDict is passed as the first argument to schedule and compute function.
    """
    _tvm_tcode = 18

    def __init__(self, handle):
        """Store ``handle``; it is later passed to ``free_extension_handle`` in ``__del__``."""
        self.handle = handle

    def __del__(self):
        """Hand the stored handle to ``tvm.nd.free_extension_handle``.

        The type code is read from the class attribute ``_tvm_tcode`` rather
        than repeated as a literal, so the two values cannot drift apart.
        """
        tvm.nd.free_extension_handle(self.handle, self._tvm_tcode)

    @property
    def _tvm_handle(self):
        """Return the ``value`` attribute of the stored handle."""
        # NOTE(review): presumably this is the raw handle TVM reads when the
        # object is passed to a packed function -- confirm.
        stored = self.handle
        return stored.value

    def __getitem__(self, key):
        """Look up ``key`` through the ``nnvm.compiler._dict_get`` packed function."""
        value = _dict_get(self, key)
        return value

    def keys(self):
        """Get list of keys in the dict.
开发者ID:masa-ito-fj,项目名称:nnvm,代码行数:31,代码来源:attr_dict.py


示例20: program_fpga

 def program_fpga(file_name):
     """Download the bitstream identified by ``file_name`` and log the action.

     ``file_name`` is first mapped to a path by the
     ``tvm.rpc.server.workpath`` global function; a ``Bitstream`` built from
     that path is then downloaded.
     """
     resolve_workpath = tvm.get_global_func("tvm.rpc.server.workpath")
     bitstream_path = resolve_workpath(file_name)
     Bitstream(bitstream_path).download()
     logging.info("Program FPGA with %s", file_name)
开发者ID:LANHUIYING,项目名称:tvm,代码行数:5,代码来源:rpc_server.py



注:本文中的tvm.get_global_func函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python tvm.lower函数代码示例发布时间:2022-05-27
下一篇:
Python tvm.decl_buffer函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap