
Conversation

kurisu6912
Collaborator

@kurisu6912 kurisu6912 commented Oct 13, 2025

Tilelang JITv2

In this PR we introduce Tilelang JITv2, a new frontend for Tilelang with modern and attractive features.

Features

Kernel Declaration

Function declaration has been simplified:

[screenshot: kernel declaration before and after JITv2]
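
The screenshot is not reproduced here. As a rough sketch of the change (the v1 factory pattern is inferred from the Kernel Call example below, the v2 annotations from the examples later in this PR; both are illustrative, not copied from the image):

# before (v1, illustrative): a factory closes over static shapes and dtype
def matmul(M, N, K, dtype):
    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        ...
    return main

# after (JITv2, illustrative): shapes and dtype come from the annotations
@tl.jit
def gemm(
    A: tl.Tensor[int, int],
    B: tl.Tensor[int, int],
):
    ...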

Kernel Call

When calling functions, tensor shapes, strides, and dtypes are automatically inferred:

# before
ker_1 = matmul(1024, 1024, 1024, 'float32')
c1 = ker_1(a1, b1)
ker_2 = matmul(1024, 1024, 512, 'float32')
c2 = ker_2(a2, b2)

# after
gemm(a1, b1)
gemm(a2, b2)

Auto Tuning

Auto tuning can be done via default arguments:

@tl.jit
def add(
    A: tl.Tensor[int],
    B: tl.Tensor[int],
    block: int = tune([128, 256, 512])
):
    ...

Or on-the-fly:

add(A, B, tune([64, 128]))
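
In either form, passing a concrete value bypasses tuning for that call, because the tune(...) default, like any Python default argument, is only used when the argument is omitted (the auto tuner itself is still listed in the TODOs below):

add(A, B)             # block comes from the tune([...]) default
add(A, B, block=256)  # explicit value; no tuning for this call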

Smarter Static Evaluation

JITv2 preserves as much Python code as possible, allowing calls to custom Python functions or conditional kernel generation:

@tl.jit
def gemm(
    ...
    split_k: bool = False
):
    block_size = my_super_block_size_heuristic(M, N, K)
    if split_k: # split_k is a constant value
        with tl.Kernel(...) as ...:
            ...
    else:
        with tl.Kernel(...) as ...:
            ...

    return C
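
Because split_k is a const argument, each distinct value should compile and cache its own specialization. A hedged illustration of the expected behavior (not verified against the implementation):

gemm(a, b)                 # first call: compiles the non-split-K variant
gemm(a, b)                 # same static configuration: cache hit, no recompile
gemm(a, b, split_k=True)   # different const value: a separate kernel is compiled and cached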

Smarter Type Hinting

JITv2 not only eliminates annoying type warnings, but also adds extensive type annotations. This helps you clearly see each Tensor’s dimensions, and marks whether a value is on the Python side or kernel side. Even generated functions and the JIT-compiled kernels have friendly type hints:

[screenshot: editor type hints for JITv2 functions and compiled kernels]

Extremely Low Overhead

JITv2's Python overhead has been aggressively optimized. In the fast path, only dynamic parameters are checked, which brings the overhead in line with calling a torch function (e.g., torch.add):

A = torch.randn(128, dtype=torch.float16, device="cuda")
B = torch.randn(128, dtype=torch.float16, device="cuda")

# torch.add:  ~ 6.5us
C_1 = A + B
# jit kernel: ~ 7.5us (cached)
C_2 = add(A, B)
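
These are microsecond-scale numbers, so the easiest way to reproduce them is to time many launches and average. A minimal sketch (assumes CUDA and the add kernel defined earlier; it measures the whole Python + launch path, which is what the overhead claim is about):

import time
import torch

def time_us(fn, iters=10_000):
    fn()                        # warm up; for the JIT kernel this also triggers compilation
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e6

print(f"torch.add : {time_us(lambda: A + B):.1f} us")
print(f"jit kernel: {time_us(lambda: add(A, B)):.1f} us")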

Architecture

The Tilelang JIT workflow:

  1. The Py-to-Py transformation generates two pieces of code: an argument parser and a JIT function generator
  2. Fast path (~1.5 μs): on a kernel call, the argument parser separates static from dynamic parameters; on a static-cache hit it calls directly into the C++ library functions
  3. Slow path: on a static-cache miss, the kernel needs to be recompiled
[whiteboard diagram: Tilelang JIT workflow]
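
In pseudocode, the dispatch roughly amounts to the following (a hypothetical sketch of the workflow above; arg_parser, static_cache, and compile_for_config are illustrative names, not the actual implementation):

def dispatch(*args, **kwargs):
    # fast path: split the call into a static (const) key and dynamic values
    const_args, dyn_args = arg_parser(*args, **kwargs)
    kernel = static_cache.get(const_args)
    if kernel is None:
        # slow path: generate the prim_func for this static configuration,
        # compile it, and cache the resulting library handle
        kernel = compile_for_config(const_args)
        static_cache[const_args] = kernel
    # cache hit: call straight into the compiled C++ library function
    return kernel(*dyn_args)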

Static & Dynamic Arguments

JITv2 inspects function signatures to determine which parameters are const and which are dyn:

  • dyn supports only int, float, and ptr; dyn values are treated as tir.Var
  • const can be any type (simple types preferred; prefer Tuple over List)
  • dyn types must be explicitly annotated; Tensor must also be explicitly annotated because its data_ptr is always dynamic
  • const arguments may deviate from their annotation (e.g., annotate int but pass a dict); note that validating this is hard (it would amount to reimplementing pydantic)
[whiteboard diagram: static vs. dynamic argument classification]
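
Applied to a signature, the classification might look like this (illustrative; the dyn/Tensor annotation style follows the argument-parser example in the next section, and cfg is a made-up parameter):

_K = dyn[int, '_K']               # dyn: lowered to a tir.Var, checked on every call

def foo(
    a: Tensor[int, _K],           # data_ptr is dyn; shape/stride/dtype are const
    b: Tensor[int, _K],           # shares the dyn _K with `a`
    c: int,                       # const: plain value, part of the static cache key
    cfg: tuple = (64, 128),       # const: any type works, Tuple preferred over List
):
    ...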

Argument Parser

JITv2 generates Python code for the fast path, which unpacks const and dyn arguments and then invokes the kernel:

  • Optimized Python statements: each statement in the fast path is carefully designed around fast-to-execute bytecode; the overhead is minimal, even slightly lower than torch.to_dlpack
  • Static check cache: the fast path performs no type checks on const variables; these are checked at compile time instead (e.g., a wrong tensor shape → cache miss → kernel compiled → value range checked)
  • Dynamic checks: the fast path performs simple dynamic checks, e.g., asserting that the two K values are equal; more complex asserts may be compiled into host code (not yet supported). For example:
_K = dyn[int, '_K']
def foo(
    a: Tensor[int, _K],
    b: Tensor[int, _K],
    c: int,
):
    pass
# generated code
def foo_fastpath(a, b, c):
    # 1. Unpack type info
    # 1.1 Unpacking a tensor ~600 ns; each of the following lines takes ~200 ns, heavily optimized
    assert a.device != __device_cpu__, "Expected a non CPU tensor"
    a__shape_0, a__shape_1 = a.shape
    a__stride_0, a__stride_1 = a.stride()
    assert b.device != __device_cpu__, "Expected a non CPU tensor"
    #                  ^- note: torch.device('cpu') costs 200+ ns; using closure trick, __device_cpu__ costs 5 ns
    b__shape_0, b__shape_1 = b.shape
    b__stride_0, b__stride_1 = b.stride()
    # 2. Construct argument lists ~20–50 ns
    __const_args__ = (
        a.dtype, a__shape_0, a__shape_1, a__stride_0, a__stride_1,
        b.dtype, b__shape_0, b__shape_1, b__stride_0, b__stride_1,
        c)
    __dyn_args__ = (a.data_ptr(), b.data_ptr())
    return __const_args__, __dyn_args__
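
The __device_cpu__ comparison above relies on binding torch.device('cpu') once when the fast path is generated. One way to get this effect is to inject the pre-built object into the namespace the generated code runs in; a minimal sketch of that trick (assumed and simplified from the description above, not the actual generator code, which may use a true closure):

import torch

generated_source = '''
def foo_fastpath(a):
    assert a.device != __device_cpu__, "Expected a non CPU tensor"
    return a.data_ptr()
'''

# Pre-build torch.device('cpu') and inject it into the namespace used by exec,
# so every call compares against an existing object (~5 ns) instead of
# constructing torch.device('cpu') anew (~200+ ns, per the comment above).
namespace = {"__device_cpu__": torch.device("cpu")}
exec(generated_source, namespace)
foo_fastpath = namespace["foo_fastpath"]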

Memory Allocation & Return Values

Inside functions, use T.alloc_global to create global buffers:

  • T.alloc_global is friendlier for type linting — it’s translated into torch.empty
  • T.alloc_xxx must be assigned to a variable (x = T.alloc_xxx()), not passed directly as a function parameter (e.g., foo(T.alloc_shared(...)) is not allowed)
  • Return objects must be global buffers; returning Python objects is not supported (e.g., returning BLOCK_M + BLOCK_N is not allowed):
@T.prim_func
def gemm(
    A: T.Tensor[int, int],
    B: T.Tensor[int, int],
    out_ty  = torch.half,
    BLOCK_M = T.tune([64, 128, 256]),
    BLOCK_N = T.tune([64, 128, 256]),
):
    # Quickly get dimensions
    (N, K), (M, K2) = A.shape, B.shape
    assert K == K2, "Expect 2 matrices with identical K dimension"
    # Allocate memory for output
    out = T.alloc_global((N, M), dtype=out_ty)
    with T.Kernel((T.ceildiv(M, BLOCK_M), T.ceildiv(N, BLOCK_N)), threads=128) as (bx, by):
        pass
    return out
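
Since T.alloc_global is described above as translating into torch.empty, its host-side effect is presumably close to this sketch (illustrative only; the real implementation also has to register the buffer as a kernel parameter/output):

import torch

def alloc_global_sketch(shape, dtype=torch.float32, device="cuda"):
    # Host-side view: an uninitialized output buffer that the generated kernel
    # writes into and the JIT function returns to the caller.
    return torch.empty(shape, dtype=dtype, device=device)

out = alloc_global_sketch((1024, 1024), dtype=torch.half)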

TODOs

  • Integrate with tl.language
  • Add auto tuner

Summary by CodeRabbit

  • New Features

    • New dtype interop utilities and public dtype aliases; v2 exposes tensor schemas, placement helpers, empty-tensor descriptors, JIT/compile/tune APIs, and AST rewriting tools.
    • JIT framework: decorator, compile/tune/benchmark workflow, parallel compilation, and autotuning result types.
    • New example notebook demonstrating JITv2 usage and migration guidance.
  • Improvements

    • Consistent dtype normalization across allocations, views, proxies, and intrinsic ops.
    • Consolidated v2 re-exports for easier imports.
  • Profiler

    • Benchmarks support timeouts and default backend switched to “cupti.”

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Oct 13, 2025

Walkthrough

Adds a dtype interop module and normalizes dtypes across language APIs; introduces TileLang v2 (AST rewriter, DSL builder, compile/JIT pipeline, tensor/schema types) with aggregated re-exports; updates many TIR intrinsics to normalize dtype strings; and adds a signal-based timeout to the benchmarking utility, changing the do_bench signature and defaults.

Changes

Cohort / File(s) Summary
DType interop core
tilelang/language/dtypes.py, tilelang/language/__init__.py
New dtype utilities providing AnyDType, conversions (get_tvm_dtype, get_torch_dtype, etc.), pointer helpers; symbolic now accepts AnyDType and normalizes via get_tvm_dtype; re-exports added.
Allocation / view / proxy
tilelang/language/allocate.py, tilelang/language/customize.py, tilelang/language/proxy.py
Allocation and view APIs annotated to accept AnyDType; runtime normalization via get_tvm_dtype before creating buffers/tensors; small bool→shared mapping in alloc_shared.
TIR op dtype normalization
tilelang/language/tir/op.py
Many intrinsic wrappers now call get_tvm_dtype_str to normalize dtype args before delegating to TVM intrinsics (ptx/mfma/vector/type helpers, annotations, reinterpret, etc.).
TileLang v2 — core modules
tilelang/language/v2/ast_rewrite.py, tilelang/language/v2/compile.py, tilelang/language/v2/jit.py, tilelang/language/v2/lang.py
New v2 stack: AST quoting & DSLMutator, DSLBuilder/IR generation, JIT compilation/tuning (JITKernel/JITDispatcher), and tensor/schema types (Dyn/Const/Strided/Tensor, MakeEmpty, Place, tune). Many new public classes/functions.
TileLang v2 API aggregation
tilelang/language/v2/__init__.py, tilelang/language/v2/v1.py
Re-exports v2 lang/compile/jit symbols and all public v1 symbols to present unified v2 API surface; purely import-based.
Profiler benchmarking timeout
tilelang/profiler/bench.py
Adds TimeoutException, timeout_handler, run_with_timeout; do_bench renamed arg to bench_fn, accepts optional timeout, and default backend set to "cupti"; integrates signal-based timeout path.
Examples
examples/jitv2/jitv2.ipynb
New Jupyter notebook demonstrating JITv2 usage, tuning, and migration examples.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant PyFunc as User Function
  participant AST as v2.ast_rewrite
  participant Builder as v2.compile.DSLBuilder
  participant JIT as v2.jit.JITDispatcher
  participant Compiler as v2.jit.compile
  participant Kernel as JITKernel

  User->>JIT: @jit(...)(PyFunc)
  JIT->>AST: quote/DSLMutator(PyFunc)
  AST-->>Builder: Transformed AST
  Builder->>Builder: build IR / prim_func
  JIT->>Compiler: compile(JITFunc)
  Compiler-->>Kernel: produce JITKernel (lib, wrapper)
  User->>Kernel: call(args)
  Kernel->>Kernel: dispatch to runtime wrapper
  Kernel-->>User: results
sequenceDiagram
  autonumber
  actor Caller
  participant Bench as profiler.bench.do_bench
  participant Runner as bench_fn
  participant Signal as signal.alarm

  Caller->>Bench: do_bench(bench_fn, timeout=T)
  alt timeout specified
    Bench->>Signal: set alarm(T)
    Bench->>Runner: run_with_timeout(Runner)
    Runner-->>Bench: result or TimeoutException
    Bench->>Signal: cancel alarm
  else no timeout
    Bench->>Runner: run()
  end
  Bench-->>Caller: latency or samples

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested reviewers

  • LeiWang1999

Poem

I hop through dtypes, from Torch to TVM light,
I nudge kernels into v2's bright night.
I bind AST branches, JIT kernels take flight,
Benchmarks tick, and alarms beep polite.
A rabbit-coded patch — small, clever, and light. 🐇✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title clearly summarizes the main addition of Tilelang JITv2 and highlights its key benefits—low overhead and enhanced syntax sugars—without extraneous details. It is concise, specific, and aligned with the core changes introduced in the pull request.
Docstring Coverage ✅ Passed No functions found in the changes. Docstring coverage check skipped.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 14

🧹 Nitpick comments (9)
tilelang/profiler/bench.py (1)

101-110: Update docstring for renamed arg & new default

Docs still refer to fn and say the backend default is "event", but the signature now uses bench_fn and defaults to "cupti". Please sync the docstring so callers aren’t misled.

tilelang/language/allocate.py (2)

22-38: Consistent dtype normalization — minor doc update

Good switch to AnyDType + get_tvm_dtype and the bool shared-scope hack. Please update docstrings to reflect dtype: AnyDType (not str) to avoid confusion.


97-125: TMEM doc args mismatch with signature

Docstring refers to num_cols:int, but API takes shape and asserts 2D. Align docs (document expected 2D shape, column constraints) or change signature if you intended a scalar columns param.

tilelang/language/v2/ast_rewrite.py (1)

114-121: Exception type nit in _parse_names

Ruff suggests TypeError for invalid type. Optional tweak:

-            raise SyntaxError("Unsupported for target")
+            raise TypeError("Unsupported target type in for-loop")
tilelang/language/v2/lang.py (1)

198-206: Optional: empty() dtype typing

For consistency with AnyDType across the codebase, consider annotating dtype as AnyDType.

-def empty(
-    shape: "Tuple[*_Shapes]",
-    dtype: torch.dtype | tvm.DataType,
+def empty(
+    shape: "Tuple[*_Shapes]",
+    dtype: AnyDType,
tilelang/language/v2/__init__.py (1)

1-33: noqa noise and wildcard import

Numerous noqa: F401 markers and a wildcard import trigger Ruff noise. Prefer explicit re-exports (__all__) and drop redundant noqas if F401 isn't enabled in CI.

-from .v1 import *  # noqa: F401
+# Prefer: from .v1 import __all__ as _v1_all; from .v1 import *; __all__ = [*locals()['__all__'], *_v1_all]
+# Or explicitly enumerate what you intend to re-export.
tilelang/language/v2/jit.py (3)

398-404: AutoTuner.run: unused dyn_args

Avoid unused binding.

-            const_args, dyn_args = self.arg_parser(*cfg.args, **cfg.kwargs)  # type: ignore
+            const_args, _ = self.arg_parser(*cfg.args, **cfg.kwargs)  # type: ignore

553-559: Unreachable code after raise in partial()

Code after raise is dead. Either implement auto-tune fallback or remove the unreachable lines.

-        if has_tune(const_args):
-            raise NotImplementedError("Please manually use ker.tune to run autotuner")
-            result = self.tune(*args, **kws)
-            assert not has_tune(result.best_args.args)
-            assert not has_tune(result.best_args.kwargs.values())
-            return self.partial(*result.best_args.args, **result.best_args.kwargs)
+        if has_tune(const_args):
+            raise NotImplementedError("Please manually use ker.tune to run autotuner")

700-719: Optional types in overloads

Type hints use bare None defaults; annotate as Optional[...] to satisfy strict type checkers.

-    target_host: Union[str, Target] = None,
+    target_host: Optional[Union[str, Target]] = None,
@@
-    pass_configs: Dict[str, Any] = None,
-    compile_flags: List[str] = None,
+    pass_configs: Optional[Dict[str, Any]] = None,
+    compile_flags: Optional[List[str]] = None,
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 340bfc5 and 1dc0d0f.

📒 Files selected for processing (13)
  • tilelang/language/__init__.py (2 hunks)
  • tilelang/language/allocate.py (8 hunks)
  • tilelang/language/customize.py (3 hunks)
  • tilelang/language/dtypes.py (1 hunks)
  • tilelang/language/proxy.py (2 hunks)
  • tilelang/language/tir/op.py (28 hunks)
  • tilelang/language/v2/__init__.py (1 hunks)
  • tilelang/language/v2/ast_rewrite.py (1 hunks)
  • tilelang/language/v2/compile.py (1 hunks)
  • tilelang/language/v2/jit.py (1 hunks)
  • tilelang/language/v2/lang.py (1 hunks)
  • tilelang/language/v2/v1.py (1 hunks)
  • tilelang/profiler/bench.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (13)
tilelang/language/customize.py (1)
tilelang/language/dtypes.py (1)
  • get_tvm_dtype (112-121)
tilelang/language/tir/op.py (1)
tilelang/language/dtypes.py (1)
  • get_tvm_dtype_str (124-127)
tilelang/language/v2/ast_rewrite.py (2)
tilelang/language/ast/ir.py (2)
  • If (1096-1112)
  • Assert (859-877)
tilelang/language/v2/compile.py (2)
  • ctx (706-710)
  • arg (540-568)
tilelang/language/v2/v1.py (18)
tilelang/language/__init__.py (5)
  • annotate_layout (114-152)
  • annotate_padding (155-189)
  • annotate_l2_hit_ratio (192-206)
  • import_source (209-211)
  • use_swizzle (105-111)
tilelang/language/parallel.py (1)
  • Parallel (8-28)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/persistent.py (1)
  • Persistent (8-27)
tilelang/language/frame.py (2)
  • has_let_value (189-198)
  • get_let_value (201-210)
tilelang/language/kernel.py (1)
  • KernelLaunchFrame (95-226)
tilelang/language/allocate.py (5)
  • alloc_local (41-53)
  • alloc_shared (22-38)
  • alloc_fragment (56-68)
  • alloc_barrier (85-94)
  • alloc_reducer (127-162)
tilelang/language/copy.py (2)
  • copy (10-86)
  • c2d_im2col (89-120)
tilelang/language/gemm.py (1)
  • gemm_v2 (216-428)
tilelang/language/experimental/gemm_sp.py (1)
  • gemm_sp (9-86)
tilelang/language/fill.py (2)
  • fill (9-21)
  • clear (24-48)
tilelang/language/reduce.py (6)
  • reduce_max (50-68)
  • reduce_min (71-84)
  • reduce_sum (87-109)
  • reduce_abssum (112-124)
  • reduce_absmax (127-139)
  • finalize_reducer (210-227)
tilelang/language/customize.py (3)
  • dp4a (10-21)
  • clamp (24-37)
  • view (53-66)
tilelang/language/atomic.py (2)
  • atomic_load (307-343)
  • atomic_store (346-396)
tilelang/language/logical.py (2)
  • any_of (10-42)
  • all_of (45-77)
tilelang/language/utils.py (1)
  • index_to_coordinates (91-110)
tilelang/language/dtypes.py (4)
  • get_cffi_dtype (143-145)
  • get_ctypes_dtype (138-140)
  • get_tvm_dtype (112-121)
  • get_tvm_ptr_type (148-151)
tilelang/language/proxy.py (1)
  • make_tensor (304-309)
tilelang/profiler/bench.py (2)
tilelang/autotuner/tuner.py (4)
  • TimeoutException (36-37)
  • timeout_handler (40-41)
  • run_with_timeout (44-53)
  • func (330-334)
tilelang/profiler/__init__.py (2)
  • func (284-286)
  • do_bench (218-281)
tilelang/language/__init__.py (1)
tilelang/language/dtypes.py (1)
  • get_tvm_dtype (112-121)
tilelang/language/dtypes.py (1)
tilelang/language/v2/lang.py (1)
  • dtype (111-112)
tilelang/language/proxy.py (2)
tilelang/language/dtypes.py (1)
  • get_tvm_dtype (112-121)
tilelang/language/ast/ir.py (1)
  • handle (1467-1497)
tilelang/language/allocate.py (2)
tilelang/language/dtypes.py (1)
  • get_tvm_dtype (112-121)
tilelang/language/ast/ir.py (1)
  • alloc_buffer (441-508)
tilelang/language/v2/__init__.py (4)
tilelang/language/v2/lang.py (13)
  • empty_data_ptr (248-251)
  • DynSchema (23-31)
  • ConstSchema (35-38)
  • TensorSchema (53-65)
  • StridedTensorSchema (42-50)
  • tune (239-245)
  • Tune (212-222)
  • dyn (142-143)
  • StridedTensor (169-170)
  • Tensor (172-173)
  • empty (198-208)
  • MakeEmpty (191-195)
  • place (268-281)
tilelang/language/v2/jit.py (10)
  • tune (510-536)
  • compile (178-179)
  • compile (562-578)
  • jit (696-697)
  • jit (701-708)
  • jit (711-739)
  • JITDispatcher (437-673)
  • par_compile (211-221)
  • par_compile (614-646)
  • macro (691-692)
tilelang/language/proxy.py (3)
  • ptr (277-301)
  • StridedTensor (258-259)
  • Tensor (255-256)
tilelang/language/v2/compile.py (10)
  • set_pass_configs (820-821)
  • get_pass_configs (476-477)
  • get_pass_configs (824-825)
  • set_compile_flags (828-829)
  • add_compile_flags (832-833)
  • get_compile_flags (479-480)
  • get_compile_flags (836-837)
  • get_params (840-841)
  • JITFunc (214-341)
  • macro (454-459)
tilelang/language/v2/compile.py (6)
tilelang/language/kernel.py (1)
  • KernelLaunchFrame (95-226)
tilelang/language/v2/ast_rewrite.py (1)
  • DSLMutator (87-310)
tilelang/language/v2/lang.py (17)
  • DynSchema (23-31)
  • StridedTensorSchema (42-50)
  • ConstSchema (35-38)
  • MakeEmpty (191-195)
  • Place (255-265)
  • _param (75-76)
  • TensorV2 (94-137)
  • name (99-100)
  • shape (103-104)
  • stride (264-265)
  • dtype (111-112)
  • strides (107-108)
  • params (121-122)
  • params (130-131)
  • params (161-162)
  • empty (198-208)
  • Tensor (172-173)
tilelang/language/dtypes.py (5)
  • get_tvm_dtype (112-121)
  • get_torch_dtype (130-135)
  • get_tvm_ptr_type (148-151)
  • get_cffi_dtype (143-145)
  • get_ctypes_dtype (138-140)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-104)
tilelang/language/ast/ir.py (8)
  • target (1682-1713)
  • decl_buffer (1137-1205)
  • let (921-956)
  • evaluate (1319-1331)
  • buffer_store (1263-1300)
  • If (1096-1112)
  • Then (1115-1123)
  • Else (1126-1134)
tilelang/language/v2/jit.py (4)
tilelang/jit/adapter/cython/adapter.py (1)
  • CythonKernelAdapter (179-526)
tilelang/language/v2/compile.py (19)
  • make_prim_func_generator (795-805)
  • generate_arg_parser (76-174)
  • DSLBuilder (422-764)
  • JITFunc (214-341)
  • JITPyFunc (44-49)
  • JITArgParser (52-59)
  • parse_args (322-323)
  • prim_func (501-502)
  • out_idx (229-230)
  • get_cffi_sig (255-264)
  • generate_global_alloc_wrapper (266-320)
  • arg (540-568)
  • get_pass_configs (476-477)
  • get_pass_configs (824-825)
  • get_compile_flags (479-480)
  • get_compile_flags (836-837)
  • get (473-474)
  • get_global_allocs (467-468)
  • get_outs (470-471)
tilelang/utils/target.py (1)
  • determine_target (54-99)
tilelang/language/v2/lang.py (9)
  • Tune (212-222)
  • TuneMany (226-236)
  • Place (255-265)
  • Tensor (172-173)
  • shape (103-104)
  • dtype (111-112)
  • strides (107-108)
  • stride (264-265)
  • name (99-100)
tilelang/language/v2/lang.py (3)
tilelang/language/dtypes.py (1)
  • VoidPtr (8-9)
tilelang/language/proxy.py (3)
  • StridedTensor (258-259)
  • Tensor (255-256)
  • ptr (277-301)
tilelang/language/v2/jit.py (1)
  • tune (510-536)
🪛 Ruff (0.13.3)
tilelang/language/v2/ast_rewrite.py

120-120: Prefer TypeError exception for invalid type

(TRY004)


120-120: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/v2/v1.py

2-2: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


5-5: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


8-8: from tilelang.language.tir.ir import * used; unable to detect undefined names

(F403)


8-8: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


9-9: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


10-10: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


12-12: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


13-13: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


16-16: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


17-17: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


18-18: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


19-19: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


20-20: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


22-22: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


24-24: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


25-25: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


26-26: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


27-27: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


28-28: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


29-29: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


31-31: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


32-32: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


33-33: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


34-34: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


36-36: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


37-37: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


38-38: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


39-39: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


40-40: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


41-41: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


42-42: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


43-43: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


45-45: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


47-47: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


48-48: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


49-49: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


50-50: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


51-51: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


52-52: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


53-53: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


54-54: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


55-55: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


56-56: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


57-57: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


59-59: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


60-60: from tilelang.language.builtin import * used; unable to detect undefined names

(F403)


60-60: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


61-61: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


63-63: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


64-64: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


65-65: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


66-66: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


67-67: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


68-68: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


70-70: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/profiler/bench.py

15-15: Unused function argument: signum

(ARG001)


15-15: Unused function argument: frame

(ARG001)


16-16: Avoid specifying long messages outside the exception class

(TRY003)


24-25: Remove exception handler; error is immediately re-raised

(TRY203)


25-25: Use raise without specifying exception name

Remove exception name

(TRY201)

tilelang/language/__init__.py

84-84: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


85-85: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


86-86: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/proxy.py

307-307: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

tilelang/language/v2/__init__.py

2-2: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


3-3: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


4-4: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


5-5: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


6-6: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


7-7: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


8-8: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


9-9: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


10-10: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


11-11: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


12-12: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


13-13: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


14-14: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


15-15: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


18-18: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


19-19: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


20-20: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


21-21: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


22-22: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


23-23: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


26-26: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


27-27: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


28-28: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


29-29: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


30-30: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


31-31: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


33-33: from .v1 import * used; unable to detect undefined names

(F403)


33-33: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/v2/compile.py

168-168: Use of exec detected

(S102)


203-203: Avoid specifying long messages outside the exception class

(TRY003)


236-236: Avoid specifying long messages outside the exception class

(TRY003)


251-251: Prefer TypeError exception for invalid type

(TRY004)


251-251: Avoid specifying long messages outside the exception class

(TRY003)


262-262: Prefer TypeError exception for invalid type

(TRY004)


262-262: Avoid specifying long messages outside the exception class

(TRY003)


292-292: Prefer TypeError exception for invalid type

(TRY004)


292-292: Avoid specifying long messages outside the exception class

(TRY003)


317-317: Use of exec detected

(S102)


329-329: Use explicit conversion flag

Replace with conversion flag

(RUF010)


330-330: Use explicit conversion flag

Replace with conversion flag

(RUF010)


331-331: Use explicit conversion flag

Replace with conversion flag

(RUF010)


332-332: Use explicit conversion flag

Replace with conversion flag

(RUF010)


333-333: Use explicit conversion flag

Replace with conversion flag

(RUF010)


334-334: Use explicit conversion flag

Replace with conversion flag

(RUF010)


335-335: Use explicit conversion flag

Replace with conversion flag

(RUF010)


336-336: Use explicit conversion flag

Replace with conversion flag

(RUF010)


418-418: Avoid specifying long messages outside the exception class

(TRY003)


563-563: Unpacked variable arg is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


570-570: Unused method argument: annot

(ARG002)


578-579: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


579-580: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


588-590: Avoid specifying long messages outside the exception class

(TRY003)


597-598: Avoid specifying long messages outside the exception class

(TRY003)


637-637: Avoid specifying long messages outside the exception class

(TRY003)


750-750: Avoid specifying long messages outside the exception class

(TRY003)


755-755: Avoid specifying long messages outside the exception class

(TRY003)


788-788: Use of exec detected

(S102)

tilelang/language/v2/jit.py

68-68: Use explicit conversion flag

Replace with conversion flag

(RUF010)


69-69: Use explicit conversion flag

Replace with conversion flag

(RUF010)


71-71: Use explicit conversion flag

Replace with conversion flag

(RUF010)


72-72: Use explicit conversion flag

Replace with conversion flag

(RUF010)


114-114: Avoid specifying long messages outside the exception class

(TRY003)


164-164: Avoid specifying long messages outside the exception class

(TRY003)


193-193: Use explicit conversion flag

Replace with conversion flag

(RUF010)


195-195: Use raise without specifying exception name

Remove exception name

(TRY201)


215-215: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


277-277: Use explicit conversion flag

Replace with conversion flag

(RUF010)


280-280: Use explicit conversion flag

Replace with conversion flag

(RUF010)


280-280: Use explicit conversion flag

Replace with conversion flag

(RUF010)


287-287: Use explicit conversion flag

Replace with conversion flag

(RUF010)


291-291: Use explicit conversion flag

Replace with conversion flag

(RUF010)


291-291: Use explicit conversion flag

Replace with conversion flag

(RUF010)


293-293: Use explicit conversion flag

Replace with conversion flag

(RUF010)


326-326: Use explicit conversion flag

Replace with conversion flag

(RUF010)


339-339: Do not catch blind exception: Exception

(BLE001)


366-366: Use raise without specifying exception name

Remove exception name

(TRY201)


368-368: Use explicit conversion flag

Replace with conversion flag

(RUF010)


400-400: Unpacked variable dyn_args is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


423-423: Do not catch blind exception: Exception

(BLE001)


426-426: Use explicit conversion flag

Replace with conversion flag

(RUF010)


444-444: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


445-445: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


616-616: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


618-618: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


636-636: Use explicit conversion flag

Replace with conversion flag

(RUF010)


638-638: Use raise without specifying exception name

Remove exception name

(TRY201)


705-705: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


706-706: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


715-715: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


716-716: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

tilelang/language/v2/lang.py

50-50: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-amd
  • GitHub Check: build-test-metal

Comment on lines +124 to +127
def get_tvm_dtype_str(ty: AnyDType) -> str:
    if isinstance(ty, str):
        return ty
    return _dtype_torch2tvm[ty]

⚠️ Potential issue | 🔴 Critical

Handle tvm.DataType (and other AnyDType cases) before mapping.

get_tvm_dtype_str currently only returns early for str inputs. For every other allowed AnyDType (e.g., the tvm.DataType objects produced by get_tvm_dtype, ir.PrimType, or the new VoidPtr sentinel), this code immediately indexes _dtype_torch2tvm. Passing any of those values now raises KeyError, which breaks call sites such as call_intrin(get_tvm_dtype(torch.float32), ...) and any of the intrinsic wrappers updated in this PR. Please normalize non-string inputs first—for example, accept tvm.DataType/ir.PrimType by returning str(ty) (or ty.dtype), detect VoidPtr, and only fall back to _dtype_torch2tvm for Python/torch dtypes.

A minimal patch might look like:

 def get_tvm_dtype_str(ty: AnyDType) -> str:
-    if isinstance(ty, str):
-        return ty
-    return _dtype_torch2tvm[ty]
+    if isinstance(ty, str):
+        return ty
+    if isinstance(ty, tvm.DataType):
+        return str(ty)
+    if isinstance(ty, ir.PrimType):
+        return str(ty)
+    if isinstance(ty, VoidPtr) or ty == VoidPtr:
+        return "handle"
+    return _dtype_torch2tvm[ty]

(Feel free to tailor the exact pointer handling, but the key is to cover all members of AnyDType without throwing.)

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def get_tvm_dtype_str(ty: AnyDType) -> str:
-    if isinstance(ty, str):
-        return ty
-    return _dtype_torch2tvm[ty]
+def get_tvm_dtype_str(ty: AnyDType) -> str:
+    if isinstance(ty, str):
+        return ty
+    if isinstance(ty, tvm.DataType):
+        return str(ty)
+    if isinstance(ty, ir.PrimType):
+        return str(ty)
+    if isinstance(ty, VoidPtr) or ty == VoidPtr:
+        return "handle"
+    return _dtype_torch2tvm[ty]
🤖 Prompt for AI Agents
In tilelang/language/dtypes.py around lines 124 to 127, get_tvm_dtype_str
currently only special-cases str and then indexes _dtype_torch2tvm, which raises
KeyError for tvm.DataType, ir.PrimType, VoidPtr, and other AnyDType variants;
change it to first normalize non-string inputs: if ty is tvm.DataType or
ir.PrimType (or has a .dtype attr) return str(ty) or ty.dtype as appropriate,
detect the VoidPtr sentinel and return the proper pointer/type string, and only
if ty is a Python/torch dtype fall back to looking up _dtype_torch2tvm; this
prevents KeyError and covers all AnyDType members before mapping.

Comment on lines +11 to +15
def ast_get_span(ast: ast.AST) -> Tuple[int, int, int, int]:
    if not ast_has_span(ast):
        return None
    return tuple(getattr(ast, attr) for attr in _span_attrs)


⚠️ Potential issue | 🟠 Major

ast_get_span returns None but annotated non-Optional

The function can return None; fix the return type.

-def ast_get_span(ast: ast.AST) -> Tuple[int, int, int, int]:
+def ast_get_span(ast: ast.AST) -> Optional[Tuple[int, int, int, int]]:
     if not ast_has_span(ast):
         return None
     return tuple(getattr(ast, attr) for attr in _span_attrs)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def ast_get_span(ast: ast.AST) -> Tuple[int, int, int, int]:
+def ast_get_span(ast: ast.AST) -> Optional[Tuple[int, int, int, int]]:
     if not ast_has_span(ast):
         return None
     return tuple(getattr(ast, attr) for attr in _span_attrs)
🤖 Prompt for AI Agents
In tilelang/language/v2/ast_rewrite.py around lines 11 to 15, ast_get_span is
annotated to return Tuple[int, int, int, int] but can return None; change the
return type to Optional[Tuple[int, int, int, int]] and add the appropriate
import from typing (Optional) at the top of the file. Update the function
signature to use Optional[...] and leave the implementation as-is; also scan
callers for any typing assumptions and handle the Optional return where
necessary.

Comment on lines +17 to +22
def ast_set_span(ast: ast.AST, span: Tuple[int, int, int, int]):
    if not ast_has_span(ast):
        return
    for attr, value in zip(_span_attrs, span):
        setattr(ast, attr, value)


⚠️ Potential issue | 🟠 Major

ast_set_span never sets on fresh nodes

Guard prevents spans from being set on freshly-built AST nodes (hasattr is False initially). Always set these attributes.

-def ast_set_span(ast: ast.AST, span: Tuple[int, int, int, int]):
-    if not ast_has_span(ast):
-        return
-    for attr, value in zip(_span_attrs, span):
-        setattr(ast, attr, value)
+def ast_set_span(ast: ast.AST, span: Tuple[int, int, int, int]):
+    for attr, value in zip(_span_attrs, span):
+        setattr(ast, attr, value)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 def ast_set_span(ast: ast.AST, span: Tuple[int, int, int, int]):
-    if not ast_has_span(ast):
-        return
     for attr, value in zip(_span_attrs, span):
         setattr(ast, attr, value)
🤖 Prompt for AI Agents
In tilelang/language/v2/ast_rewrite.py around lines 17 to 22, the current
early-return prevents spans from being applied to newly created AST nodes;
remove the guard and always set the span attributes. Replace the body so it
simply iterates over _span_attrs and the provided span tuple and calls
setattr(ast, attr, value) for each pair (no hasattr check), ensuring
freshly-built nodes get their span fields populated.

Comment on lines +61 to +65
def quote_expr(expr: str, **kws) -> List[ast.AST]:
    res = quote1(expr, **kws)
    assert isinstance(res, ast.Expr)
    return res.value


⚠️ Potential issue | 🟡 Minor

quote_expr return type

Returns an expression node, not a list.

-def quote_expr(expr: str, **kws) -> List[ast.AST]:
+def quote_expr(expr: str, **kws) -> ast.AST:
     res = quote1(expr, **kws)
     assert isinstance(res, ast.Expr)
     return res.value
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-def quote_expr(expr: str, **kws) -> List[ast.AST]:
+def quote_expr(expr: str, **kws) -> ast.AST:
     res = quote1(expr, **kws)
     assert isinstance(res, ast.Expr)
     return res.value
🤖 Prompt for AI Agents
In tilelang/language/v2/ast_rewrite.py around lines 61 to 65, the function
quote_expr is annotated to return List[ast.AST] but actually returns a single
expression node (res.value); update the function signature to return ast.AST
(not List[ast.AST]) and adjust any imports/typing references accordingly, and
scan/adjust callers if they expect a list so they handle a single ast.AST
instead.

Comment on lines +231 to +244
        all_args = node.args.posonlyargs + node.args.args
        if node.args.vararg is not None:
            all_args += node.args.vararg
        all_args += node.args.kwonlyargs
        stmts = []
        for arg in all_args:
            name = arg.arg
            if arg.annotation is not None:
                arg_stmt = quote1(
                    f'{name} = __tb.arg("{name}", {name}, annot)', annot=arg.annotation, span=arg)
            else:
                arg_stmt = quote1(f'{name} = __tb.arg("{name}", {name})', span=arg)
            stmts.append(arg_stmt)
        node.decorator_list.pop(0)

⚠️ Potential issue | 🔴 Critical

FunctionDef vararg handling bug (TypeError at runtime)

all_args += node.args.vararg tries to add ast.arg to list. Use append.

-        if node.args.vararg is not None:
-            all_args += node.args.vararg
+        if node.args.vararg is not None:
+            all_args.append(node.args.vararg)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
         all_args = node.args.posonlyargs + node.args.args
         if node.args.vararg is not None:
-            all_args += node.args.vararg
+            all_args.append(node.args.vararg)
         all_args += node.args.kwonlyargs
         stmts = []
         for arg in all_args:
             name = arg.arg
             if arg.annotation is not None:
                 arg_stmt = quote1(
                     f'{name} = __tb.arg("{name}", {name}, annot)', annot=arg.annotation, span=arg)
             else:
                 arg_stmt = quote1(f'{name} = __tb.arg("{name}", {name})', span=arg)
             stmts.append(arg_stmt)
         node.decorator_list.pop(0)
🤖 Prompt for AI Agents
In tilelang/language/v2/ast_rewrite.py around lines 231 to 244, the code uses
all_args += node.args.vararg which attempts to concatenate an ast.arg object to
a list and raises a TypeError at runtime; change that to
all_args.append(node.args.vararg) (keeping the existing "if node.args.vararg is
not None" guard) so the vararg ast.arg is added as a single element to the
all_args list.

Comment on lines +130 to +132
    def params(self):
        return *self.shape_params(), self.dtype


⚠️ Potential issue | 🔴 Critical

Syntax error: starred return in params()

return *... is invalid in Python. Wrap in a tuple literal.

-    def params(self):
-        return *self.shape_params(), self.dtype
+    def params(self):
+        return (*self.shape_params(), self.dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     def params(self):
-        return *self.shape_params(), self.dtype
+        return (*self.shape_params(), self.dtype)
🤖 Prompt for AI Agents
In tilelang/language/v2/lang.py around lines 130 to 132, the params() method
currently uses an invalid starred return (return *self.shape_params(),
self.dtype); change it to return a tuple by wrapping the expanded values in
parentheses — e.g. return (*self.shape_params(), self.dtype) — so the shape
parameters are expanded into the returned tuple along with dtype.

Comment on lines +221 to +223
    def __ne__(self, rhs) -> bool:
        return isinstance(rhs, Tune) and self.data == rhs.data


⚠️ Potential issue | 🔴 Critical

Tune.__ne__ implements equality, not inequality

__ne__ should be the logical negation of __eq__.

-    def __ne__(self, rhs) -> bool:
-        return isinstance(rhs, Tune) and self.data == rhs.data
+    def __ne__(self, rhs) -> bool:
+        return not self.__eq__(rhs)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     def __ne__(self, rhs) -> bool:
-        return isinstance(rhs, Tune) and self.data == rhs.data
+        return not self.__eq__(rhs)
🤖 Prompt for AI Agents
In tilelang/language/v2/lang.py around lines 221 to 223, __ne__ currently
returns the same result as equality; change it to be the logical negation of
__eq__. Implement __ne__ to return not self.__eq__(rhs) (or equivalently return
not (isinstance(rhs, Tune) and self.data == rhs.data)) so inequality correctly
mirrors equality.

Comment on lines +279 to +281
    if device is None:
        device = torch.cuda.current_device()
    return Place(shape=shape, strides=strides, dtype=dtype, device=device)

⚠️ Potential issue | 🔴 Critical

place(): device type mismatch (int vs torch.device)

torch.cuda.current_device() returns int; Place.device expects torch.device and is compared against torch.device in arg parsing. Use torch.device("cuda") for consistency.

-    if device is None:
-        device = torch.cuda.current_device()
+    if device is None:
+        device = torch.device("cuda")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     if device is None:
-        device = torch.cuda.current_device()
+        device = torch.device("cuda")
     return Place(shape=shape, strides=strides, dtype=dtype, device=device)
🤖 Prompt for AI Agents
In tilelang/language/v2/lang.py around lines 279 to 281, the code sets device =
torch.cuda.current_device() which returns an int while Place.device expects a
torch.device; change the fallback to construct a torch.device (e.g. device =
torch.device("cuda") or torch.device(f"cuda:{torch.cuda.current_device()}") if
you need the specific GPU index) so Place receives a torch.device instance and
comparisons in arg parsing work correctly.

Comment on lines +289 to +327
def _apply_tvm_patches():

    def __array_eq(self, rhs):
        if isinstance(rhs, tuple):
            return tuple(self) == rhs
        if isinstance(rhs, list):
            return list(self) == rhs
        if isinstance(rhs, tvm.ffi.container.Array):
            return tvm.core.Object.__eq__(self, rhs)

    def __array_ne(self, rhs):
        if isinstance(rhs, tuple):
            return tuple(self) != rhs
        if isinstance(rhs, list):
            return list(self) != rhs
        if isinstance(rhs, tvm.ffi.container.Array):
            return tvm.core.Object.__ne__(self, rhs)

    def __array_req(self, lhs):
        if isinstance(lhs, tuple):
            return tuple(self) == lhs
        if isinstance(lhs, list):
            return list(self) == lhs
        if isinstance(lhs, tvm.ffi.container.Array):
            return tvm.core.Object.__eq__(self, lhs)

    def __array_rne(self, rhs):
        if isinstance(rhs, tuple):
            return tuple(self) != rhs
        if isinstance(rhs, list):
            return list(self) != rhs
        if isinstance(rhs, tvm.ffi.container.Array):
            return tvm.core.Object.__ne__(self, rhs)

    tvm.ffi.container.Array.__eq__ = __array_eq
    tvm.ffi.container.Array.__ne__ = __array_ne
    tvm.ffi.container.Array.__req__ = __array_req
    tvm.ffi.container.Array.__rne__ = __array_rne


⚠️ Potential issue | 🟠 Major

Custom tvm Array equality should return NotImplemented by default

If rhs/lhs is an unsupported type, return NotImplemented instead of falling through (None). Prevents surprising behavior in equality.

     def __array_eq(self, rhs):
         if isinstance(rhs, tuple):
             return tuple(self) == rhs
         if isinstance(rhs, list):
             return list(self) == rhs
         if isinstance(rhs, tvm.ffi.container.Array):
             return tvm.core.Object.__eq__(self, rhs)
+        return NotImplemented
 
     def __array_ne(self, rhs):
         if isinstance(rhs, tuple):
             return tuple(self) != rhs
         if isinstance(rhs, list):
             return list(self) != rhs
         if isinstance(rhs, tvm.ffi.container.Array):
             return tvm.core.Object.__ne__(self, rhs)
+        return NotImplemented
 
     def __array_req(self, lhs):
         if isinstance(lhs, tuple):
             return tuple(self) == lhs
         if isinstance(lhs, list):
             return list(self) == lhs
         if isinstance(lhs, tvm.ffi.container.Array):
             return tvm.core.Object.__eq__(self, lhs)
+        return NotImplemented
 
     def __array_rne(self, rhs):
         if isinstance(rhs, tuple):
             return tuple(self) != rhs
         if isinstance(rhs, list):
             return list(self) != rhs
         if isinstance(rhs, tvm.ffi.container.Array):
             return tvm.core.Object.__ne__(self, rhs)
+        return NotImplemented
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 def _apply_tvm_patches():

     def __array_eq(self, rhs):
         if isinstance(rhs, tuple):
             return tuple(self) == rhs
         if isinstance(rhs, list):
             return list(self) == rhs
         if isinstance(rhs, tvm.ffi.container.Array):
             return tvm.core.Object.__eq__(self, rhs)
+        return NotImplemented

     def __array_ne(self, rhs):
         if isinstance(rhs, tuple):
             return tuple(self) != rhs
         if isinstance(rhs, list):
             return list(self) != rhs
         if isinstance(rhs, tvm.ffi.container.Array):
             return tvm.core.Object.__ne__(self, rhs)
+        return NotImplemented

     def __array_req(self, lhs):
         if isinstance(lhs, tuple):
             return tuple(self) == lhs
         if isinstance(lhs, list):
             return list(self) == lhs
         if isinstance(lhs, tvm.ffi.container.Array):
             return tvm.core.Object.__eq__(self, lhs)
+        return NotImplemented

     def __array_rne(self, rhs):
         if isinstance(rhs, tuple):
             return tuple(self) != rhs
         if isinstance(rhs, list):
             return list(self) != rhs
         if isinstance(rhs, tvm.ffi.container.Array):
             return tvm.core.Object.__ne__(self, rhs)
+        return NotImplemented

     tvm.ffi.container.Array.__eq__ = __array_eq
     tvm.ffi.container.Array.__ne__ = __array_ne
     tvm.ffi.container.Array.__req__ = __array_req
     tvm.ffi.container.Array.__rne__ = __array_rne
🤖 Prompt for AI Agents
In tilelang/language/v2/lang.py around lines 289-327, the custom tvm Array
equality and inequality helpers currently fall through and implicitly return
None for unsupported types; change each helper (__array_eq, __array_ne,
__array_req, __array_rne) so that after the existing isinstance checks they
explicitly return NotImplemented when rhs/lhs is an unsupported type, ensuring
Python's rich comparison machinery handles those cases correctly.

Comment on lines +19 to +27
def run_with_timeout(func, timeout, *args, **kwargs):
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout)
    try:
        result = func(*args, **kwargs)
    except Exception as e:
        raise e
    finally:
        signal.alarm(0)

⚠️ Potential issue | 🔴 Critical

Timeout impl crashes on Windows

signal.SIGALRM doesn’t exist on Windows, so any call to run_with_timeout will raise AttributeError. The new timeout option therefore breaks outright on that platform. Guard the feature (hasattr(signal, "SIGALRM")), fall back gracefully (or raise a clear OS error), and restore the previous handler after the alarm:

 def run_with_timeout(func, timeout, *args, **kwargs):
-    signal.signal(signal.SIGALRM, timeout_handler)
-    signal.alarm(timeout)
+    if not hasattr(signal, "SIGALRM"):
+        raise RuntimeError("Timeout-based benchmarking is unsupported on this platform")
+    previous_handler = signal.getsignal(signal.SIGALRM)
+    signal.signal(signal.SIGALRM, timeout_handler)
+    signal.setitimer(signal.ITIMER_REAL, timeout)
     try:
         result = func(*args, **kwargs)
     except Exception as e:
         raise e
     finally:
-        signal.alarm(0)
+        signal.setitimer(signal.ITIMER_REAL, 0)
+        signal.signal(signal.SIGALRM, previous_handler)
     return result
📝 Committable suggestion


Suggested change
def run_with_timeout(func, timeout, *args, **kwargs):
    if not hasattr(signal, "SIGALRM"):
        raise RuntimeError("Timeout-based benchmarking is unsupported on this platform")
    previous_handler = signal.getsignal(signal.SIGALRM)
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.setitimer(signal.ITIMER_REAL, timeout)
    try:
        result = func(*args, **kwargs)
    except Exception as e:
        raise e
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, previous_handler)
    return result
🧰 Tools
🪛 Ruff (0.13.3)

24-25: Remove exception handler; error is immediately re-raised

(TRY203)


25-25: Use raise without specifying exception name

Remove exception name

(TRY201)
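
For reference, a minimal sketch of the handler with both Ruff findings applied (the redundant except/re-raise removed); this keeps the original SIGALRM approach and leaves the Windows concern above aside. The timeout_handler shown here is an assumed stand-in for the module's own handler:

import signal

def timeout_handler(signum, frame):
    # Assumed to mirror the module's handler: abort the benchmarked call.
    raise TimeoutError("benchmark timed out")

def run_with_timeout(func, timeout, *args, **kwargs):
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout)
    try:
        # No except clause that only re-raises (TRY203/TRY201); the finally
        # block still cancels the alarm on both success and failure.
        return func(*args, **kwargs)
    finally:
        signal.alarm(0)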

🤖 Prompt for AI Agents
In tilelang/profiler/bench.py around lines 19 to 27, the timeout implementation
uses signal.SIGALRM which doesn't exist on Windows and will raise
AttributeError; update run_with_timeout to first check hasattr(signal,
"SIGALRM") and if absent either raise a clear OSError indicating timeouts are
unsupported on this platform or call the function without a timeout, and when
using SIGALRM save the previous signal handler before calling signal.signal, set
the alarm, wrap the function call in try/except/finally, cancel the alarm and
restore the previous handler in the finally block so the original signal state
is preserved.

@kurisu6912 kurisu6912 changed the title Jit v2 [Feature] Tilelang JITv2: Low Overhead and Syntax Sugars Oct 14, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

🧹 Nitpick comments (5)
examples/jitv2/jitv2.ipynb (5)

331-333: Nit: Improve wording.

Duplicate word and clarity.

  • “use tl.StridedTensor to to use strided tensor” → “Use tl.StridedTensor for strided tensors”
  • Consider clarifying which axes/fields are static vs dynamic.

343-343: Rename get_contingous → get_contiguous (spelling).

Improves readability and consistency with PyTorch terminology.

-def get_contingous(
+def get_contiguous(
-out = get_contingous(A[:, 0, :, 0])
+out = get_contiguous(A[:, 0, :, 0])

Also applies to: 365-365


1787-1794: Validate heads/groups consistency and use heads_kv.

Add checks so grouped KV heads are coherent; also resolves the “unused heads_kv” static hint.

-    batch, seq_len, heads, dim_qk, dtype = Q.params()
-    batch, seq_len, heads_kv, dim_qk, dtype = K.params()
-    batch, seq_len, heads_kv, dim_v, dtype = V.params()
+    batch, seq_len, heads, dim_qk, dtype = Q.params()
+    batch, seq_len, heads_kv, dim_qk, dtype = K.params()
+    batch, seq_len, heads_kv2, dim_v, dtype = V.params()
+    assert heads_kv == heads_kv2, "K/V must have same heads_kv"
+    assert heads % groups == 0, "heads must be divisible by groups"
+    assert heads_kv == heads // groups, "heads_kv must equal heads // groups"

738-749: Fix typos in autotune section.

Minor wording fixes for clarity.

  • “Advanced Tunning” → “Advanced Tuning”
  • “kerl.get_tune_configs”/“kerl.tune_configs” → “ker.get_tune_configs”/“ker.tune_configs” (or match the API you present, e.g., gemm.get_tune_configs / gemm.tune_configs)

1136-1137: Clarify parallel compile wording.

Use accurate name and phrasing.

  • “Use ker.par_compile to compile many kernel args parallely” → “Use gemm.par_compile to compile many kernel args in parallel”
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1dc0d0f and 25c7d1c.

📒 Files selected for processing (1)
  • examples/jitv2/jitv2.ipynb (1 hunks)
🧰 Additional context used
🪛 Ruff (0.14.0)
examples/jitv2/jitv2.ipynb

232-232: Found useless expression. Either assign it to a variable or remove it.

(B018)


373-373: Found useless expression. Either assign it to a variable or remove it.

(B018)


380-380: Found useless expression. Either assign it to a variable or remove it.

(B018)


388-388: Found useless expression. Either assign it to a variable or remove it.

(B018)


512-512: Unpacked variable heads_kv is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-metal

Comment on lines +61 to +68

    N, K = A.shape_params()
    M, K = B.shape_params()

    C = tl.empty((M, N), dtype=accum_dtype)
    dims = [
        tl.ceildiv(M, block_M),
        tl.ceildiv(N, block_N),
    ]

⚠️ Potential issue | 🔴 Critical

Fix GEMM shape inference (A/B dims swapped) and add inner-dim assert.

The current code treats A as (N, K) and B as (M, K), but the indexing and the gemm call expect A = (M, K) and B = (K, N). This can silently produce wrong results. Adjust the shape extraction and assert that the inner dimensions match.

-    N, K = A.shape_params()
-    M, K = B.shape_params()
+    M, K = A.shape_params()
+    K2, N = B.shape_params()
+    assert K == K2, "Incompatible inner dimensions: A.shape[1] != B.shape[0]"
📝 Committable suggestion


Suggested change

    M, K = A.shape_params()
    K2, N = B.shape_params()
    assert K == K2, "Incompatible inner dimensions: A.shape[1] != B.shape[0]"
    C = tl.empty((M, N), dtype=accum_dtype)
    dims = [
        tl.ceildiv(M, block_M),
        tl.ceildiv(N, block_N),
    ]
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 61 to 68, the code currently treats A
as (N,K) and B as (M,K) which is reversed for the GEMM implementation; update
the shape extraction to treat A as (M,K) and B as (K,N) (e.g., M,K =
A.shape_params(); K2,N = B.shape_params()), add an assertion that K == K2 to
validate the inner dimension matches, and keep C = tl.empty((M, N), ...) and the
dims computation unchanged so the output shape and tiling remain correct.

Comment on lines +174 to +176

    A = tl.make_tensor(A_ptr, (M, K), dtype=dtype)
    B = tl.make_tensor(B_ptr, (N, K), dtype=dtype)
    dims = [

⚠️ Potential issue | 🔴 Critical

Correct B tensor shape in pointer-based GEMM.

Indexing uses B[k, bx*block_N], so B must be (K, N), not (N, K).

-    A = tl.make_tensor(A_ptr, (M, K), dtype=dtype)
-    B = tl.make_tensor(B_ptr, (N, K), dtype=dtype)
+    A = tl.make_tensor(A_ptr, (M, K), dtype=dtype)
+    B = tl.make_tensor(B_ptr, (K, N), dtype=dtype)
📝 Committable suggestion


Suggested change

    A = tl.make_tensor(A_ptr, (M, K), dtype=dtype)
    B = tl.make_tensor(B_ptr, (K, N), dtype=dtype)
    dims = [
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 174 to 176, the B tensor is created
with shape (N, K) but the kernel indexes B as B[k, bx*block_N], so B must be
shaped (K, N); change the tl.make_tensor call for B to use (K, N) as its shape
(and update any nearby variable names/comments if necessary to reflect the
corrected shape).

Comment on lines +252 to +256

    with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx:
        px = bx * block_N
        for i in tl.Parallel(block_N):
            C[px + i] = A[px + i] + B[px + i]
    return C

⚠️ Potential issue | 🟠 Major

Guard tail to avoid OOB in vec_add.

When N % block_N != 0, the last block accesses out-of-bounds elements.

-        for i in tl.Parallel(block_N):
-            C[px + i] = A[px + i] + B[px + i]
+        for i in tl.Parallel(block_N):
+            if px + i < N:
+                C[px + i] = A[px + i] + B[px + i]
📝 Committable suggestion


Suggested change

    with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx:
        px = bx * block_N
        for i in tl.Parallel(block_N):
            if px + i < N:
                C[px + i] = A[px + i] + B[px + i]
    return C
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 252 to 256, the vec_add kernel
iterates full block_N across all blocks and will read/write past N when N %
block_N != 0; add a bounds guard for the tail: compute the global indices (px +
i) and mask them with a boolean valid = px + i < N, then use that mask to
conditionally load/store (or use tl.where/elementwise masking) so only in-bounds
lanes perform A+B and write to C; ensure the kernel returns C unchanged for
out-of-bounds lanes.

Comment on lines +297 to +301

    with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx:
        px = bx * block_N
        for i in tl.Parallel(block_N):
            C[px + i] = A[px + i] + cval
    return C

⚠️ Potential issue | 🟠 Major

Guard tail to avoid OOB in vec_add_scalar.

Same tail issue as vec_add.

-        for i in tl.Parallel(block_N):
-            C[px + i] = A[px + i] + cval
+        for i in tl.Parallel(block_N):
+            if px + i < N:
+                C[px + i] = A[px + i] + cval
📝 Committable suggestion


Suggested change

    with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx:
        px = bx * block_N
        for i in tl.Parallel(block_N):
            if px + i < N:
                C[px + i] = A[px + i] + cval
    return C
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 297-301, the vec_add_scalar kernel
writes past the array end for the tail block; guard the tail by computing the
global index (idx = px + i) and only performing the write when idx < N (e.g., in
the loop add an if idx < N: C[idx] = A[idx] + cval or use a masked store),
ensuring out-of-bounds accesses are prevented.

Comment on lines +415 to +427

torch_beg = time.perf_counter()
for _ in range(10000):
    A @ B
torch_end = time.perf_counter()
elapsed = (torch_end - torch_beg) / 10000 * 1e6

print('Torch time: ', elapsed, 'us')

tl_beg = time.perf_counter()
for _ in range(10000):
    gemm(A, B)
tl_end = time.perf_counter()
elapsed = (tl_end - tl_beg) / 10000 * 1e6

⚠️ Potential issue | 🟠 Major

Synchronize CUDA for accurate timing.

GPU launches are async; without synchronize(), the timings are unreliable.

- torch_beg = time.perf_counter()
- for _ in range(10000):
-     A @ B
- torch_end = time.perf_counter()
+ torch.cuda.synchronize()
+ torch_beg = time.perf_counter()
+ for _ in range(10000):
+     A @ B
+ torch.cuda.synchronize()
+ torch_end = time.perf_counter()
@@
- tl_beg = time.perf_counter()
- for _ in range(10000):
-     gemm(A, B)
- tl_end = time.perf_counter()
+ torch.cuda.synchronize()
+ tl_beg = time.perf_counter()
+ for _ in range(10000):
+     gemm(A, B)
+ torch.cuda.synchronize()
+ tl_end = time.perf_counter()
📝 Committable suggestion


Suggested change

# Ensure any prior GPU work is finished
torch.cuda.synchronize()
torch_beg = time.perf_counter()
for _ in range(10000):
    A @ B
# Wait for the last A @ B to complete
torch.cuda.synchronize()
torch_end = time.perf_counter()
elapsed = (torch_end - torch_beg) / 10000 * 1e6

print('Torch time: ', elapsed, 'us')

# Repeat for the Tilelang gemm
torch.cuda.synchronize()
tl_beg = time.perf_counter()
for _ in range(10000):
    gemm(A, B)
# Wait for the last gemm to complete
torch.cuda.synchronize()
tl_end = time.perf_counter()
elapsed = (tl_end - tl_beg) / 10000 * 1e6

print('Tilelang time: ', elapsed, 'us')
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 415-427, the timing around the GPU
matrix multiplies is unreliable because CUDA launches are asynchronous; add
explicit synchronization calls so the measured intervals include kernel
execution. Specifically, call torch.cuda.synchronize() (guarded by
torch.cuda.is_available()) immediately before starting each perf_counter() and
immediately after each timed loop (both for the A @ B loop and the gemm(A, B)
loop) so the timers bracket actual GPU work.

Comment on lines +455 to +467

torch_beg = time.perf_counter()
for _ in range(10000):
    A + B
torch_end = time.perf_counter()
elapsed = (torch_end - torch_beg) / 10000 * 1e6

print('Torch time: ', elapsed, 'us')

tl_beg = time.perf_counter()
for _ in range(10000):
    vec_add(A, B)
tl_end = time.perf_counter()
elapsed = (tl_end - tl_beg) / 10000 * 1e6

⚠️ Potential issue | 🟠 Major

Synchronize CUDA for accurate timing (Vec Add).

Same issue as GEMM benchmark.

- torch_beg = time.perf_counter()
- for _ in range(10000):
-     A + B
- torch_end = time.perf_counter()
+ torch.cuda.synchronize()
+ torch_beg = time.perf_counter()
+ for _ in range(10000):
+     A + B
+ torch.cuda.synchronize()
+ torch_end = time.perf_counter()
@@
- tl_beg = time.perf_counter()
- for _ in range(10000):
-     vec_add(A, B)
- tl_end = time.perf_counter()
+ torch.cuda.synchronize()
+ tl_beg = time.perf_counter()
+ for _ in range(10000):
+     vec_add(A, B)
+ torch.cuda.synchronize()
+ tl_end = time.perf_counter()
📝 Committable suggestion


Suggested change

torch.cuda.synchronize()
torch_beg = time.perf_counter()
for _ in range(10000):
    A + B
torch.cuda.synchronize()
torch_end = time.perf_counter()
elapsed = (torch_end - torch_beg) / 10000 * 1e6

print('Torch time: ', elapsed, 'us')

torch.cuda.synchronize()
tl_beg = time.perf_counter()
for _ in range(10000):
    vec_add(A, B)
torch.cuda.synchronize()
tl_end = time.perf_counter()
elapsed = (tl_end - tl_beg) / 10000 * 1e6
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 455 to 467, the Vec Add timing
measures asynchronous CUDA launches leading to incorrect timings; synchronize
the CUDA device before starting each timer and again after the loop (e.g.,
torch.cuda.synchronize()) so all GPU work is completed prior to reading the
clock, and apply the same synchronization pattern used in the GEMM benchmark to
both the PyTorch A+B loop and the vec_add(A, B) loop.

Comment on lines +950 to +952

    (M, K), (N, K2) = A.shape, B.shape
    assert K == K2, "Expect matrix A and B to have the same number of columns"
    C = tl.empty((M, N), dtype=accum_dtype)

⚠️ Potential issue | 🔴 Critical

Fix the shape destructuring in gemm_tune_advanced (B dims order).

B is (K, N); the current destructuring makes the assert compare K with N.

-    (M, K), (N, K2) = A.shape, B.shape
-    assert K == K2, "Expect matrix A and B to have the same number of columns"
+    M, K = A.shape
+    K2, N = B.shape
+    assert K == K2, "Expect A.shape[1] == B.shape[0]"
📝 Committable suggestion


Suggested change

    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "Expect A.shape[1] == B.shape[0]"
    C = tl.empty((M, N), dtype=accum_dtype)
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 950 to 952, the code destructures
B.shape as (N, K2) but B is actually (K, N); change the destructuring to "(M,
K), (K2, N) = A.shape, B.shape" and keep/ensure the assert as "assert K == K2,
..." so the shared dimension is compared correctly; no change needed to the
creation of C which should remain tl.empty((M, N), dtype=accum_dtype).

@LeiWang1999
Member

@codex review


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines +157 to +163
source = ""
source += "def parse_args(" + ", ".join(closure.keys()) + "):\n"
source += f" def {fn_name}(" + ", ".join(func_args) + ", __stream__=None" + "):\n"
source += " " + "\n ".join(code_parse_arg) + "\n"
source += " __const_args__ = (" + ", ".join(tup_const) + ")\n"
source += " __dyn_args__ = (" + ", ".join(tup_dyn) + ", __stream__, __device__)\n"
source += " return __const_args__, __dyn_args__\n"


P1 Badge Handle zero dynamic arguments when generating arg parser

When generate_arg_parser builds the source string it always emits __dyn_args__ = ( …, __stream__, __device__) but it relies on ", ".join(tup_dyn) to insert the leading values. If a kernel has no dynamic parameters (no tensors or DynSchema arguments), tup_dyn is empty and the emitted code becomes __dyn_args__ = (, __stream__, __device__), which is invalid Python and causes a SyntaxError the first time such a kernel is decorated. Consider appending __stream__ and __device__ to the list before joining or adding a conditional to avoid producing an empty tuple prefix.
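
A self-contained sketch of the suggested fix, using a hypothetical helper name (build_dyn_args_expr) rather than the project's actual generator code:

def build_dyn_args_expr(tup_dyn):
    # Append __stream__ and __device__ before joining, so an empty tup_dyn
    # still renders a valid tuple literal instead of "(, __stream__, __device__)".
    items = [*tup_dyn, "__stream__", "__device__"]
    # The trailing comma is harmless and keeps the literal a tuple in every case.
    return "__dyn_args__ = (" + ", ".join(items) + ",)"

print(build_dyn_args_expr(("A_data_ptr", "K")))
# __dyn_args__ = (A_data_ptr, K, __stream__, __device__,)
print(build_dyn_args_expr(()))
# __dyn_args__ = (__stream__, __device__,)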


Comment on lines 118 to 126
    # Initial function call and synchronization
    def fn():
        if timeout is not None:
            run_with_timeout(bench_fn, timeout)
        else:
            bench_fn()

    fn()
    torch.accelerator.synchronize()


P1 Badge Replace nonexistent torch.accelerator.synchronize call

The new benchmarking routine now calls torch.accelerator.synchronize() right after the warm‑up call. PyTorch does not expose a torch.accelerator module, so do_bench will raise AttributeError before any timing runs (even on CUDA where the previous code used torch.cuda.synchronize()). This regression makes the default benchmarking utility unusable until the correct backend-specific synchronize (e.g., torch.cuda.synchronize() or an explicit branch for MPS) is restored.
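
A hedged sketch of a backend-aware fallback (the helper name _synchronize is ours, not the module's; torch.accelerator only exists in newer PyTorch releases, hence the hasattr guard):

import torch

def _synchronize() -> None:
    # Prefer the generic accelerator API when this PyTorch build provides it,
    # otherwise fall back to the backend-specific synchronize calls.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        torch.accelerator.synchronize()
    elif torch.cuda.is_available():
        torch.cuda.synchronize()
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        torch.mps.synchronize()
    # CPU-only builds have nothing to synchronize.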

