[Feature] Tilelang JITv2: Low Overhead and Syntax Sugars #1003

Conversation
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run […]. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Adds a dtype interop module and normalizes dtypes across language APIs; introduces TileLang v2 (AST rewriter, DSL builder, compile/JIT pipeline, tensor/schema types) with aggregated re-exports; updates many TIR intrinsics to normalize dtype strings; and adds a signal-based timeout to the benchmarking utility, changing the do_bench signature and defaults.
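For orientation, the dtype interop described above boils down to mapping str/torch dtypes onto canonical TVM dtype strings. A minimal sketch of the idea, based on the snippets quoted later in this review rather than the PR's exact code (the mapping table here is an assumption and covers only a few dtypes):

```python
import torch

# Assumed mapping table; the real module covers many more dtypes.
_dtype_torch2tvm = {
    torch.float16: "float16",
    torch.bfloat16: "bfloat16",
    torch.float32: "float32",
    torch.int32: "int32",
}

def get_tvm_dtype_str(ty) -> str:
    # Strings pass through unchanged; torch dtypes are translated.
    if isinstance(ty, str):
        return ty
    return _dtype_torch2tvm[ty]

assert get_tvm_dtype_str("float16") == "float16"
assert get_tvm_dtype_str(torch.float32) == "float32"
```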
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant PyFunc as User Function
    participant AST as v2.ast_rewrite
    participant Builder as v2.compile.DSLBuilder
    participant JIT as v2.jit.JITDispatcher
    participant Compiler as v2.jit.compile
    participant Kernel as JITKernel
    User->>JIT: @jit(...)(PyFunc)
    JIT->>AST: quote/DSLMutator(PyFunc)
    AST-->>Builder: Transformed AST
    Builder->>Builder: build IR / prim_func
    JIT->>Compiler: compile(JITFunc)
    Compiler-->>Kernel: produce JITKernel (lib, wrapper)
    User->>Kernel: call(args)
    Kernel->>Kernel: dispatch to runtime wrapper
    Kernel-->>User: results
```
```mermaid
sequenceDiagram
    autonumber
    actor Caller
    participant Bench as profiler.bench.do_bench
    participant Runner as bench_fn
    participant Signal as signal.alarm
    Caller->>Bench: do_bench(bench_fn, timeout=T)
    alt timeout specified
        Bench->>Signal: set alarm(T)
        Bench->>Runner: run_with_timeout(Runner)
        Runner-->>Bench: result or TimeoutException
        Bench->>Signal: cancel alarm
    else no timeout
        Bench->>Runner: run()
    end
    Bench-->>Caller: latency or samples
```
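A minimal sketch of the timeout flow shown above, assuming a Unix platform; the helper names follow the symbols cited in this review (`TimeoutException`, `timeout_handler`, `run_with_timeout`) but the body is illustrative, not the PR's actual implementation:

```python
import signal

class TimeoutException(Exception):
    pass

def _handler(signum, frame):
    # SIGALRM fired before the benchmark finished
    raise TimeoutException("benchmark timed out")

def run_with_timeout(func, timeout, *args, **kwargs):
    previous = signal.signal(signal.SIGALRM, _handler)
    signal.alarm(timeout)  # arm the alarm (whole seconds)
    try:
        return func(*args, **kwargs)
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, previous)  # restore the prior handler
```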
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes

Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 14
🧹 Nitpick comments (9)
tilelang/profiler/bench.py (1)

101-110: Update docstring for renamed arg & new default

Docs still refer to `fn` and say the backend default is `"event"`, but the signature now uses `bench_fn` and defaults to `"cupti"`. Please sync the docstring so callers aren't misled.

tilelang/language/allocate.py (2)
22-38: Consistent dtype normalization; minor doc update

Good switch to AnyDType + get_tvm_dtype and the bool shared-scope hack. Please update the docstrings to reflect `dtype: AnyDType` (not `str`) to avoid confusion.

97-125: TMEM doc args mismatch with signature

The docstring refers to `num_cols: int`, but the API takes `shape` and asserts 2D. Align the docs (document the expected 2D shape and column constraints) or change the signature if you intended a scalar columns param.

tilelang/language/v2/ast_rewrite.py (1)

114-121: Exception type nit in _parse_names

Ruff suggests TypeError for the invalid type. Optional tweak:

```diff
-            raise SyntaxError("Unsupported for target")
+            raise TypeError("Unsupported target type in for-loop")
```
tilelang/language/v2/lang.py (1)

198-206: Optional: empty() dtype typing

For consistency with AnyDType across the codebase, consider annotating `dtype` as `AnyDType`.

```diff
-def empty(
-    shape: "Tuple[*_Shapes]",
-    dtype: torch.dtype | tvm.DataType,
+def empty(
+    shape: "Tuple[*_Shapes]",
+    dtype: AnyDType,
```
tilelang/language/v2/__init__.py (1)

1-33: noqa noise and wildcard import

Numerous `noqa: F401` markers and a wildcard import trigger Ruff noise. Prefer explicit re-exports (`__all__`) and drop the redundant noqas if F401 isn't enabled in CI.

```diff
-from .v1 import *  # noqa: F401
+# Prefer: from .v1 import __all__ as _v1_all; from .v1 import *; __all__ = [*locals()['__all__'], *_v1_all]
+# Or explicitly enumerate what you intend to re-export.
```
tilelang/language/v2/jit.py (3)

398-404: AutoTuner.run: unused dyn_args

Avoid the unused binding.

```diff
-        const_args, dyn_args = self.arg_parser(*cfg.args, **cfg.kwargs)  # type: ignore
+        const_args, _ = self.arg_parser(*cfg.args, **cfg.kwargs)  # type: ignore
```

553-559: Unreachable code after raise in partial()

Code after the raise is dead. Either implement the auto-tune fallback or remove the unreachable lines.

```diff
-        if has_tune(const_args):
-            raise NotImplementedError("Please manually use ker.tune to run autotuner")
-        result = self.tune(*args, **kws)
-        assert not has_tune(result.best_args.args)
-        assert not has_tune(result.best_args.kwargs.values())
-        return self.partial(*result.best_args.args, **result.best_args.kwargs)
+        if has_tune(const_args):
+            raise NotImplementedError("Please manually use ker.tune to run autotuner")
```

700-719: Optional types in overloads

Type hints use bare `None` defaults; annotate them as `Optional[...]` to satisfy strict type checkers.

```diff
-    target_host: Union[str, Target] = None,
+    target_host: Optional[Union[str, Target]] = None,
@@
-    pass_configs: Dict[str, Any] = None,
-    compile_flags: List[str] = None,
+    pass_configs: Optional[Dict[str, Any]] = None,
+    compile_flags: Optional[List[str]] = None,
```
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (13)

- tilelang/language/__init__.py (2 hunks)
- tilelang/language/allocate.py (8 hunks)
- tilelang/language/customize.py (3 hunks)
- tilelang/language/dtypes.py (1 hunk)
- tilelang/language/proxy.py (2 hunks)
- tilelang/language/tir/op.py (28 hunks)
- tilelang/language/v2/__init__.py (1 hunk)
- tilelang/language/v2/ast_rewrite.py (1 hunk)
- tilelang/language/v2/compile.py (1 hunk)
- tilelang/language/v2/jit.py (1 hunk)
- tilelang/language/v2/lang.py (1 hunk)
- tilelang/language/v2/v1.py (1 hunk)
- tilelang/profiler/bench.py (3 hunks)
🧰 Additional context used

🧬 Code graph analysis (13)

tilelang/language/customize.py (1)
- tilelang/language/dtypes.py (1): get_tvm_dtype (112-121)

tilelang/language/tir/op.py (1)
- tilelang/language/dtypes.py (1): get_tvm_dtype_str (124-127)

tilelang/language/v2/ast_rewrite.py (2)
- tilelang/language/ast/ir.py (2): If (1096-1112), Assert (859-877)
- tilelang/language/v2/compile.py (2): ctx (706-710), arg (540-568)

tilelang/language/v2/v1.py (18)
- tilelang/language/__init__.py (5): annotate_layout (114-152), annotate_padding (155-189), annotate_l2_hit_ratio (192-206), import_source (209-211), use_swizzle (105-111)
- tilelang/language/parallel.py (1): Parallel (8-28)
- tilelang/language/pipeline.py (1): Pipelined (9-46)
- tilelang/language/persistent.py (1): Persistent (8-27)
- tilelang/language/frame.py (2): has_let_value (189-198), get_let_value (201-210)
- tilelang/language/kernel.py (1): KernelLaunchFrame (95-226)
- tilelang/language/allocate.py (5): alloc_local (41-53), alloc_shared (22-38), alloc_fragment (56-68), alloc_barrier (85-94), alloc_reducer (127-162)
- tilelang/language/copy.py (2): copy (10-86), c2d_im2col (89-120)
- tilelang/language/gemm.py (1): gemm_v2 (216-428)
- tilelang/language/experimental/gemm_sp.py (1): gemm_sp (9-86)
- tilelang/language/fill.py (2): fill (9-21), clear (24-48)
- tilelang/language/reduce.py (6): reduce_max (50-68), reduce_min (71-84), reduce_sum (87-109), reduce_abssum (112-124), reduce_absmax (127-139), finalize_reducer (210-227)
- tilelang/language/customize.py (3): dp4a (10-21), clamp (24-37), view (53-66)
- tilelang/language/atomic.py (2): atomic_load (307-343), atomic_store (346-396)
- tilelang/language/logical.py (2): any_of (10-42), all_of (45-77)
- tilelang/language/utils.py (1): index_to_coordinates (91-110)
- tilelang/language/dtypes.py (4): get_cffi_dtype (143-145), get_ctypes_dtype (138-140), get_tvm_dtype (112-121), get_tvm_ptr_type (148-151)
- tilelang/language/proxy.py (1): make_tensor (304-309)

tilelang/profiler/bench.py (2)
- tilelang/autotuner/tuner.py (4): TimeoutException (36-37), timeout_handler (40-41), run_with_timeout (44-53), func (330-334)
- tilelang/profiler/__init__.py (2): func (284-286), do_bench (218-281)

tilelang/language/__init__.py (1)
- tilelang/language/dtypes.py (1): get_tvm_dtype (112-121)

tilelang/language/dtypes.py (1)
- tilelang/language/v2/lang.py (1): dtype (111-112)

tilelang/language/proxy.py (2)
- tilelang/language/dtypes.py (1): get_tvm_dtype (112-121)
- tilelang/language/ast/ir.py (1): handle (1467-1497)

tilelang/language/allocate.py (2)
- tilelang/language/dtypes.py (1): get_tvm_dtype (112-121)
- tilelang/language/ast/ir.py (1): alloc_buffer (441-508)

tilelang/language/v2/__init__.py (4)
- tilelang/language/v2/lang.py (13): empty_data_ptr (248-251), DynSchema (23-31), ConstSchema (35-38), TensorSchema (53-65), StridedTensorSchema (42-50), tune (239-245), Tune (212-222), dyn (142-143), StridedTensor (169-170), Tensor (172-173), empty (198-208), MakeEmpty (191-195), place (268-281)
- tilelang/language/v2/jit.py (10): tune (510-536), compile (178-179), compile (562-578), jit (696-697), jit (701-708), jit (711-739), JITDispatcher (437-673), par_compile (211-221), par_compile (614-646), macro (691-692)
- tilelang/language/proxy.py (3): ptr (277-301), StridedTensor (258-259), Tensor (255-256)
- tilelang/language/v2/compile.py (10): set_pass_configs (820-821), get_pass_configs (476-477), get_pass_configs (824-825), set_compile_flags (828-829), add_compile_flags (832-833), get_compile_flags (479-480), get_compile_flags (836-837), get_params (840-841), JITFunc (214-341), macro (454-459)

tilelang/language/v2/compile.py (6)
- tilelang/language/kernel.py (1): KernelLaunchFrame (95-226)
- tilelang/language/v2/ast_rewrite.py (1): DSLMutator (87-310)
- tilelang/language/v2/lang.py (17): DynSchema (23-31), StridedTensorSchema (42-50), ConstSchema (35-38), MakeEmpty (191-195), Place (255-265), _param (75-76), TensorV2 (94-137), name (99-100), shape (103-104), stride (264-265), dtype (111-112), strides (107-108), params (121-122), params (130-131), params (161-162), empty (198-208), Tensor (172-173)
- tilelang/language/dtypes.py (5): get_tvm_dtype (112-121), get_torch_dtype (130-135), get_tvm_ptr_type (148-151), get_cffi_dtype (143-145), get_ctypes_dtype (138-140)
- tilelang/transform/pass_config.py (1): PassConfigKey (6-104)
- tilelang/language/ast/ir.py (8): target (1682-1713), decl_buffer (1137-1205), let (921-956), evaluate (1319-1331), buffer_store (1263-1300), If (1096-1112), Then (1115-1123), Else (1126-1134)

tilelang/language/v2/jit.py (4)
- tilelang/jit/adapter/cython/adapter.py (1): CythonKernelAdapter (179-526)
- tilelang/language/v2/compile.py (19): make_prim_func_generator (795-805), generate_arg_parser (76-174), DSLBuilder (422-764), JITFunc (214-341), JITPyFunc (44-49), JITArgParser (52-59), parse_args (322-323), prim_func (501-502), out_idx (229-230), get_cffi_sig (255-264), generate_global_alloc_wrapper (266-320), arg (540-568), get_pass_configs (476-477), get_pass_configs (824-825), get_compile_flags (479-480), get_compile_flags (836-837), get (473-474), get_global_allocs (467-468), get_outs (470-471)
- tilelang/utils/target.py (1): determine_target (54-99)
- tilelang/language/v2/lang.py (9): Tune (212-222), TuneMany (226-236), Place (255-265), Tensor (172-173), shape (103-104), dtype (111-112), strides (107-108), stride (264-265), name (99-100)

tilelang/language/v2/lang.py (3)
- tilelang/language/dtypes.py (1): VoidPtr (8-9)
- tilelang/language/proxy.py (3): StridedTensor (258-259), Tensor (255-256), ptr (277-301)
- tilelang/language/v2/jit.py (1): tune (510-536)
🪛 Ruff (0.13.3)

tilelang/language/v2/ast_rewrite.py
- 120: Prefer `TypeError` exception for invalid type (TRY004); avoid specifying long messages outside the exception class (TRY003)

tilelang/language/v2/v1.py
- 2-6, 8-13, 15-20, 22, 24-29, 31-34, 36-43, 45, 47-57, 59-61, 63-68, 70: Unused `noqa` directive (non-enabled: F401); remove it (RUF100)
- 8, 60: `from tilelang.language.tir.ir import *` / `from tilelang.language.builtin import *` used; unable to detect undefined names (F403)

tilelang/profiler/bench.py
- 15: Unused function arguments `signum`, `frame` (ARG001)
- 16: Avoid specifying long messages outside the exception class (TRY003)
- 24-25: Remove exception handler; error is immediately re-raised (TRY203)
- 25: Use `raise` without specifying the exception name (TRY201)

tilelang/language/__init__.py
- 84-86: Unused `noqa` directive (non-enabled: F401); remove it (RUF100)

tilelang/language/proxy.py
- 307: PEP 484 prohibits implicit Optional; convert to `Optional[T]` (RUF013)

tilelang/language/v2/__init__.py
- 2-15, 18-23, 26-31, 33: Unused `noqa` directive (non-enabled: F401); remove it (RUF100)
- 33: `from .v1 import *` used; unable to detect undefined names (F403)

tilelang/language/v2/compile.py
- 168, 317, 788: Use of `exec` detected (S102)
- 203, 236, 251, 262, 292, 418, 588-590, 597-598, 637, 750, 755: Avoid specifying long messages outside the exception class (TRY003)
- 251, 262, 292: Prefer `TypeError` exception for invalid type (TRY004)
- 329-336: Use explicit conversion flag (RUF010)
- 563: Unpacked variable `arg` is never used; prefix it with an underscore or any other dummy variable pattern (RUF059)
- 570: Unused method argument `annot` (ARG002)
- 578-580: Parenthesize `a and b` expressions when chaining `and` and `or` together, to make the precedence clear (RUF021)

tilelang/language/v2/jit.py
- 68, 69, 71, 72, 193, 277, 280, 287, 291, 293, 326, 368, 426, 636: Use explicit conversion flag (RUF010)
- 114, 164: Avoid specifying long messages outside the exception class (TRY003)
- 195, 366, 638: Use `raise` without specifying the exception name (TRY201)
- 215, 444, 445, 616, 618, 705, 706, 715, 716: PEP 484 prohibits implicit Optional; convert to `Optional[T]` (RUF013)
- 339, 423: Do not catch blind exception: `Exception` (BLE001)
- 400: Unpacked variable `dyn_args` is never used; prefix it with an underscore or any other dummy variable pattern (RUF059)

tilelang/language/v2/lang.py
- 50: Avoid specifying long messages outside the exception class (TRY003)
⏰ Context from checks skipped due to timeout of 90000 ms (2). You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000 ms).
- GitHub Check: build-test-amd
- GitHub Check: build-test-metal
```python
def get_tvm_dtype_str(ty: AnyDType) -> str:
    if isinstance(ty, str):
        return ty
    return _dtype_torch2tvm[ty]
```
Handle `tvm.DataType` (and other `AnyDType` cases) before mapping.

`get_tvm_dtype_str` currently only returns early for `str` inputs. For every other allowed `AnyDType` (e.g., the `tvm.DataType` objects produced by `get_tvm_dtype`, `ir.PrimType`, or the new `VoidPtr` sentinel), this code immediately indexes `_dtype_torch2tvm`. Passing any of those values now raises `KeyError`, which breaks call sites such as `call_intrin(get_tvm_dtype(torch.float32), ...)` and any of the intrinsic wrappers updated in this PR. Please normalize non-string inputs first; for example, accept `tvm.DataType`/`ir.PrimType` by returning `str(ty)` (or `ty.dtype`), detect `VoidPtr`, and only fall back to `_dtype_torch2tvm` for Python/torch dtypes.

A minimal patch might look like:

```diff
 def get_tvm_dtype_str(ty: AnyDType) -> str:
-    if isinstance(ty, str):
-        return ty
-    return _dtype_torch2tvm[ty]
+    if isinstance(ty, str):
+        return ty
+    if isinstance(ty, tvm.DataType):
+        return str(ty)
+    if isinstance(ty, ir.PrimType):
+        return str(ty)
+    if isinstance(ty, VoidPtr) or ty == VoidPtr:
+        return "handle"
+    return _dtype_torch2tvm[ty]
```

(Feel free to tailor the exact pointer handling, but the key is to cover all members of `AnyDType` without throwing.)
🤖 Prompt for AI Agents
In tilelang/language/dtypes.py around lines 124 to 127, get_tvm_dtype_str
currently only special-cases str and then indexes _dtype_torch2tvm, which raises
KeyError for tvm.DataType, ir.PrimType, VoidPtr, and other AnyDType variants;
change it to first normalize non-string inputs: if ty is tvm.DataType or
ir.PrimType (or has a .dtype attr) return str(ty) or ty.dtype as appropriate,
detect the VoidPtr sentinel and return the proper pointer/type string, and only
if ty is a Python/torch dtype fall back to looking up _dtype_torch2tvm; this
prevents KeyError and covers all AnyDType members before mapping.
```python
def ast_get_span(ast: ast.AST) -> Tuple[int, int, int, int]:
    if not ast_has_span(ast):
        return None
    return tuple(getattr(ast, attr) for attr in _span_attrs)
```
ast_get_span returns None but is annotated non-Optional

The function can return None; fix the return type.

```diff
-def ast_get_span(ast: ast.AST) -> Tuple[int, int, int, int]:
+def ast_get_span(ast: ast.AST) -> Optional[Tuple[int, int, int, int]]:
     if not ast_has_span(ast):
         return None
     return tuple(getattr(ast, attr) for attr in _span_attrs)
```
🤖 Prompt for AI Agents
In tilelang/language/v2/ast_rewrite.py around lines 11 to 15, ast_get_span is
annotated to return Tuple[int, int, int, int] but can return None; change the
return type to Optional[Tuple[int, int, int, int]] and add the appropriate
import from typing (Optional) at the top of the file. Update the function
signature to use Optional[...] and leave the implementation as-is; also scan
callers for any typing assumptions and handle the Optional return where
necessary.
```python
def ast_set_span(ast: ast.AST, span: Tuple[int, int, int, int]):
    if not ast_has_span(ast):
        return
    for attr, value in zip(_span_attrs, span):
        setattr(ast, attr, value)
```
ast_set_span never sets on fresh nodes

The guard prevents spans from being set on freshly-built AST nodes (hasattr is False initially). Always set these attributes.

```diff
-def ast_set_span(ast: ast.AST, span: Tuple[int, int, int, int]):
-    if not ast_has_span(ast):
-        return
-    for attr, value in zip(_span_attrs, span):
-        setattr(ast, attr, value)
+def ast_set_span(ast: ast.AST, span: Tuple[int, int, int, int]):
+    for attr, value in zip(_span_attrs, span):
+        setattr(ast, attr, value)
```
🤖 Prompt for AI Agents
In tilelang/language/v2/ast_rewrite.py around lines 17 to 22, the current
early-return prevents spans from being applied to newly created AST nodes;
remove the guard and always set the span attributes. Replace the body so it
simply iterates over _span_attrs and the provided span tuple and calls
setattr(ast, attr, value) for each pair (no hasattr check), ensuring
freshly-built nodes get their span fields populated.
```python
def quote_expr(expr: str, **kws) -> List[ast.AST]:
    res = quote1(expr, **kws)
    assert isinstance(res, ast.Expr)
    return res.value
```
quote_expr return type

Returns an expression node, not a list.

```diff
-def quote_expr(expr: str, **kws) -> List[ast.AST]:
+def quote_expr(expr: str, **kws) -> ast.AST:
     res = quote1(expr, **kws)
     assert isinstance(res, ast.Expr)
     return res.value
```
🤖 Prompt for AI Agents
In tilelang/language/v2/ast_rewrite.py around lines 61 to 65, the function
quote_expr is annotated to return List[ast.AST] but actually returns a single
expression node (res.value); update the function signature to return ast.AST
(not List[ast.AST]) and adjust any imports/typing references accordingly, and
scan/adjust callers if they expect a list so they handle a single ast.AST
instead.
```python
all_args = node.args.posonlyargs + node.args.args
if node.args.vararg is not None:
    all_args += node.args.vararg
all_args += node.args.kwonlyargs
stmts = []
for arg in all_args:
    name = arg.arg
    if arg.annotation is not None:
        arg_stmt = quote1(
            f'{name} = __tb.arg("{name}", {name}, annot)', annot=arg.annotation, span=arg)
    else:
        arg_stmt = quote1(f'{name} = __tb.arg("{name}", {name})', span=arg)
    stmts.append(arg_stmt)
node.decorator_list.pop(0)
```
FunctionDef vararg handling bug (TypeError at runtime)

`all_args += node.args.vararg` tries to add an `ast.arg` to a list. Use `append`.

```diff
-    if node.args.vararg is not None:
-        all_args += node.args.vararg
+    if node.args.vararg is not None:
+        all_args.append(node.args.vararg)
```
🤖 Prompt for AI Agents
In tilelang/language/v2/ast_rewrite.py around lines 231 to 244, the code uses
all_args += node.args.vararg which attempts to concatenate an ast.arg object to
a list and raises a TypeError at runtime; change that to
all_args.append(node.args.vararg) (keeping the existing "if node.args.vararg is
not None" guard) so the vararg ast.arg is added as a single element to the
all_args list.
```python
def params(self):
    return *self.shape_params(), self.dtype
```
Starred return in params()

`return *...` without parentheses is a syntax error before Python 3.8 (and reads ambiguously even where it is allowed). Wrap it in a tuple literal.

```diff
-    def params(self):
-        return *self.shape_params(), self.dtype
+    def params(self):
+        return (*self.shape_params(), self.dtype)
```
🤖 Prompt for AI Agents
In tilelang/language/v2/lang.py around lines 130 to 132, the params() method uses an unparenthesized starred return (return *self.shape_params(), self.dtype), which is a syntax error before Python 3.8; change it to return a tuple by wrapping the expanded values in parentheses, e.g. return (*self.shape_params(), self.dtype), so the shape parameters are expanded into the returned tuple along with dtype.
```python
def __ne__(self, rhs) -> bool:
    return isinstance(rhs, Tune) and self.data == rhs.data
```
Tune.__ne__ implements equality, not inequality

`__ne__` should be the logical negation of `__eq__`.

```diff
-    def __ne__(self, rhs) -> bool:
-        return isinstance(rhs, Tune) and self.data == rhs.data
+    def __ne__(self, rhs) -> bool:
+        return not self.__eq__(rhs)
```
🤖 Prompt for AI Agents
In tilelang/language/v2/lang.py around lines 221 to 223, __ne__ currently
returns the same result as equality; change it to be the logical negation of
__eq__. Implement __ne__ to return not self.__eq__(rhs) (or equivalently return
not (isinstance(rhs, Tune) and self.data == rhs.data)) so inequality correctly
mirrors equality.
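A side note on general Python semantics (not from this PR): since Python 3, a class that defines only `__eq__` already gets the negated behavior for `!=`, so dropping `__ne__` entirely is also a valid fix. A tiny illustration with a hypothetical class:

```python
class Box:
    def __init__(self, v):
        self.v = v

    def __eq__(self, rhs):
        return isinstance(rhs, Box) and self.v == rhs.v

assert Box(1) == Box(1)
assert Box(1) != Box(2)  # != derived automatically; no __ne__ defined
```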
```python
if device is None:
    device = torch.cuda.current_device()
return Place(shape=shape, strides=strides, dtype=dtype, device=device)
```
place(): device type mismatch (int vs torch.device)

`torch.cuda.current_device()` returns an `int`; `Place.device` expects a `torch.device` and is compared against `torch.device` in arg parsing. Use `torch.device("cuda")` for consistency.

```diff
-    if device is None:
-        device = torch.cuda.current_device()
+    if device is None:
+        device = torch.device("cuda")
```
🤖 Prompt for AI Agents
In tilelang/language/v2/lang.py around lines 279 to 281, the code sets device =
torch.cuda.current_device() which returns an int while Place.device expects a
torch.device; change the fallback to construct a torch.device (e.g. device =
torch.device("cuda") or torch.device(f"cuda:{torch.cuda.current_device()}") if
you need the specific GPU index) so Place receives a torch.device instance and
comparisons in arg parsing work correctly.
```python
def _apply_tvm_patches():

    def __array_eq(self, rhs):
        if isinstance(rhs, tuple):
            return tuple(self) == rhs
        if isinstance(rhs, list):
            return list(self) == rhs
        if isinstance(rhs, tvm.ffi.container.Array):
            return tvm.core.Object.__eq__(self, rhs)

    def __array_ne(self, rhs):
        if isinstance(rhs, tuple):
            return tuple(self) != rhs
        if isinstance(rhs, list):
            return list(self) != rhs
        if isinstance(rhs, tvm.ffi.container.Array):
            return tvm.core.Object.__ne__(self, rhs)

    def __array_req(self, lhs):
        if isinstance(lhs, tuple):
            return tuple(self) == lhs
        if isinstance(lhs, list):
            return list(self) == lhs
        if isinstance(lhs, tvm.ffi.container.Array):
            return tvm.core.Object.__eq__(self, lhs)

    def __array_rne(self, rhs):
        if isinstance(rhs, tuple):
            return tuple(self) != rhs
        if isinstance(rhs, list):
            return list(self) != rhs
        if isinstance(rhs, tvm.ffi.container.Array):
            return tvm.core.Object.__ne__(self, rhs)

    tvm.ffi.container.Array.__eq__ = __array_eq
    tvm.ffi.container.Array.__ne__ = __array_ne
    tvm.ffi.container.Array.__req__ = __array_req
    tvm.ffi.container.Array.__rne__ = __array_rne
```
Custom tvm Array equality should return NotImplemented by default

If rhs/lhs is an unsupported type, return `NotImplemented` instead of falling through (returning `None`). This prevents surprising behavior in equality.

```diff
 def __array_eq(self, rhs):
     if isinstance(rhs, tuple):
         return tuple(self) == rhs
     if isinstance(rhs, list):
         return list(self) == rhs
     if isinstance(rhs, tvm.ffi.container.Array):
         return tvm.core.Object.__eq__(self, rhs)
+    return NotImplemented

 def __array_ne(self, rhs):
     if isinstance(rhs, tuple):
         return tuple(self) != rhs
     if isinstance(rhs, list):
         return list(self) != rhs
     if isinstance(rhs, tvm.ffi.container.Array):
         return tvm.core.Object.__ne__(self, rhs)
+    return NotImplemented

 def __array_req(self, lhs):
     if isinstance(lhs, tuple):
         return tuple(self) == lhs
     if isinstance(lhs, list):
         return list(self) == lhs
     if isinstance(lhs, tvm.ffi.container.Array):
         return tvm.core.Object.__eq__(self, lhs)
+    return NotImplemented

 def __array_rne(self, rhs):
     if isinstance(rhs, tuple):
         return tuple(self) != rhs
     if isinstance(rhs, list):
         return list(self) != rhs
     if isinstance(rhs, tvm.ffi.container.Array):
         return tvm.core.Object.__ne__(self, rhs)
+    return NotImplemented
```
🤖 Prompt for AI Agents
In tilelang/language/v2/lang.py around lines 289-327, the custom tvm Array
equality and inequality helpers currently fall through and implicitly return
None for unsupported types; change each helper (__array_eq, __array_ne,
__array_req, __array_rne) so that after the existing isinstance checks they
explicitly return NotImplemented when rhs/lhs is an unsupported type, ensuring
Python's rich comparison machinery handles those cases correctly.
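To see why `NotImplemented` matters (a toy example, not TVM code): when one operand's comparison method returns it, Python retries with the reflected method on the other operand instead of treating the comparison as decided.

```python
class A:
    def __eq__(self, rhs):
        if isinstance(rhs, A):
            return True
        return NotImplemented  # let the other operand have a say

class B:
    def __eq__(self, lhs):
        # reflected path: runs because A.__eq__ returned NotImplemented
        return isinstance(lhs, A)

assert A() == B()
```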
```python
def run_with_timeout(func, timeout, *args, **kwargs):
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout)
    try:
        result = func(*args, **kwargs)
    except Exception as e:
        raise e
    finally:
        signal.alarm(0)
```
Timeout impl crashes on Windows

`signal.SIGALRM` doesn't exist on Windows, so any call to `run_with_timeout` will raise `AttributeError`. The new `timeout` option therefore breaks outright on that platform. Guard the feature (`hasattr(signal, "SIGALRM")`), fall back gracefully (or raise a clear OS error), and restore the previous handler after the alarm:

```diff
 def run_with_timeout(func, timeout, *args, **kwargs):
-    signal.signal(signal.SIGALRM, timeout_handler)
-    signal.alarm(timeout)
+    if not hasattr(signal, "SIGALRM"):
+        raise RuntimeError("Timeout-based benchmarking is unsupported on this platform")
+    previous_handler = signal.getsignal(signal.SIGALRM)
+    signal.signal(signal.SIGALRM, timeout_handler)
+    signal.setitimer(signal.ITIMER_REAL, timeout)
     try:
         result = func(*args, **kwargs)
     except Exception as e:
         raise e
     finally:
-        signal.alarm(0)
+        signal.setitimer(signal.ITIMER_REAL, 0)
+        signal.signal(signal.SIGALRM, previous_handler)
     return result
```
🧰 Tools

🪛 Ruff (0.13.3)
- 24-25: Remove exception handler; error is immediately re-raised (TRY203)
- 25: Use `raise` without specifying the exception name (TRY201)
🤖 Prompt for AI Agents
In tilelang/profiler/bench.py around lines 19 to 27, the timeout implementation
uses signal.SIGALRM which doesn't exist on Windows and will raise
AttributeError; update run_with_timeout to first check hasattr(signal,
"SIGALRM") and if absent either raise a clear OSError indicating timeouts are
unsupported on this platform or call the function without a timeout, and when
using SIGALRM save the previous signal handler before calling signal.signal, set
the alarm, wrap the function call in try/except/finally, cancel the alarm and
restore the previous handler in the finally block so the original signal state
is preserved.
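If Windows support matters, one portable alternative (an assumption, not what this PR does) is a thread-based timeout. Note the caveat in the comments: a worker thread cannot be forcibly killed, so a timed-out benchmark may still run to completion in the background.

```python
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FutureTimeout

def run_with_timeout_portable(func, timeout, *args, **kwargs):
    pool = ThreadPoolExecutor(max_workers=1)
    future = pool.submit(func, *args, **kwargs)
    try:
        return future.result(timeout=timeout)
    except FutureTimeout:
        raise TimeoutError("benchmark timed out") from None
    finally:
        # don't block waiting for a stuck worker; it cannot be killed
        pool.shutdown(wait=False)
```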
Actionable comments posted: 7
🧹 Nitpick comments (5)
examples/jitv2/jitv2.ipynb (5)

331-333: Nit: Improve wording

Duplicate word and clarity.

- "use tl.StridedTensor to to use strided tensor" → "Use tl.StridedTensor for strided tensors"
- Consider clarifying which axes/fields are static vs dynamic.

343: Rename get_contingous → get_contiguous (spelling)

Improves readability and consistency with PyTorch terminology.

```diff
-def get_contingous(
+def get_contiguous(
```

```diff
-out = get_contingous(A[:, 0, :, 0])
+out = get_contiguous(A[:, 0, :, 0])
```

Also applies to: 365

1787-1794: Validate heads/groups consistency and use heads_kv

Add checks so grouped KV heads are coherent; this also resolves the "unused heads_kv" static hint.

```diff
-    batch, seq_len, heads, dim_qk, dtype = Q.params()
-    batch, seq_len, heads_kv, dim_qk, dtype = K.params()
-    batch, seq_len, heads_kv, dim_v, dtype = V.params()
+    batch, seq_len, heads, dim_qk, dtype = Q.params()
+    batch, seq_len, heads_kv, dim_qk, dtype = K.params()
+    batch, seq_len, heads_kv2, dim_v, dtype = V.params()
+    assert heads_kv == heads_kv2, "K/V must have same heads_kv"
+    assert heads % groups == 0, "heads must be divisible by groups"
+    assert heads_kv == heads // groups, "heads_kv must equal heads // groups"
```

738-749: Fix typos in autotune section

Minor wording fixes for clarity.

- "Advanced Tunning" → "Advanced Tuning"
- "kerl.get_tune_configs"/"kerl.tune_configs" → "ker.get_tune_configs"/"ker.tune_configs" (or match the API you present, e.g., gemm.get_tune_configs / gemm.tune_configs)

1136-1137: Clarify parallel compile wording

Use the accurate name and phrasing.

- "Use ker.par_compile to compile many kernel args parallely" → "Use gemm.par_compile to compile many kernel args in parallel"
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)

- examples/jitv2/jitv2.ipynb (1 hunk)

🧰 Additional context used

🪛 Ruff (0.14.0)

examples/jitv2/jitv2.ipynb
- 232, 373, 380, 388: Found useless expression. Either assign it to a variable or remove it. (B018)
- 512: Unpacked variable `heads_kv` is never used; prefix it with an underscore or any other dummy variable pattern (RUF059)

⏰ Context from checks skipped due to timeout of 90000 ms (1). You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000 ms).
- GitHub Check: build-test-metal
" N, K = A.shape_params()\n", | ||
" M, K = B.shape_params()\n", | ||
"\n", | ||
" C = tl.empty((M, N), dtype=accum_dtype)\n", | ||
" dims = [\n", | ||
" tl.ceildiv(M, block_M),\n", | ||
" tl.ceildiv(N, block_N),\n", | ||
" ]\n", |
Fix GEMM shape inference (A/B dims swapped) and add an inner-dim assert.

The current code treats A as (N, K) and B as (M, K), but the indexing/gemm expects A = (M, K), B = (K, N). This can silently produce wrong results. Adjust the extraction and assert that K matches.

```diff
-    N, K = A.shape_params()
-    M, K = B.shape_params()
+    M, K = A.shape_params()
+    K2, N = B.shape_params()
+    assert K == K2, "Incompatible inner dimensions: A.shape[1] != B.shape[0]"
```
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 61 to 68, the code currently treats A
as (N,K) and B as (M,K) which is reversed for the GEMM implementation; update
the shape extraction to treat A as (M,K) and B as (K,N) (e.g., M,K =
A.shape_params(); K2,N = B.shape_params()), add an assertion that K == K2 to
validate the inner dimension matches, and keep C = tl.empty((M, N), ...) and the
dims computation unchanged so the output shape and tiling remain correct.
" A = tl.make_tensor(A_ptr, (M, K), dtype=dtype)\n", | ||
" B = tl.make_tensor(B_ptr, (N, K), dtype=dtype)\n", | ||
" dims = [\n", |
Correct B tensor shape in pointer-based GEMM.

Indexing uses `B[k, bx*block_N]`, so B must be `(K, N)`, not `(N, K)`.

```diff
 A = tl.make_tensor(A_ptr, (M, K), dtype=dtype)
-B = tl.make_tensor(B_ptr, (N, K), dtype=dtype)
+B = tl.make_tensor(B_ptr, (K, N), dtype=dtype)
```
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 174 to 176, the B tensor is created
with shape (N, K) but the kernel indexes B as B[k, bx*block_N], so B must be
shaped (K, N); change the tl.make_tensor call for B to use (K, N) as its shape
(and update any nearby variable names/comments if necessary to reflect the
corrected shape).
" with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx:\n", | ||
" px = bx * block_N\n", | ||
" for i in tl.Parallel(block_N):\n", | ||
" C[px + i] = A[px + i] + B[px + i]\n", | ||
" return C" |
Guard tail to avoid OOB in vec_add.

When `N % block_N != 0`, the last block accesses out of bounds.

```diff
-    for i in tl.Parallel(block_N):
-        C[px + i] = A[px + i] + B[px + i]
+    for i in tl.Parallel(block_N):
+        if px + i < N:
+            C[px + i] = A[px + i] + B[px + i]
```
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 252 to 256, the vec_add kernel
iterates full block_N across all blocks and will read/write past N when N %
block_N != 0; add a bounds guard for the tail: compute the global indices (px +
i) and mask them with a boolean valid = px + i < N, then use that mask to
conditionally load/store (or use tl.where/elementwise masking) so only in-bounds
lanes perform A+B and write to C; ensure the kernel returns C unchanged for
out-of-bounds lanes.
" with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx:\n", | ||
" px = bx * block_N\n", | ||
" for i in tl.Parallel(block_N):\n", | ||
" C[px + i] = A[px + i] + cval\n", | ||
" return C" |
Guard tail to avoid OOB in vec_add_scalar.
Same tail issue as vec_add.
-    for i in tl.Parallel(block_N):
-        C[px + i] = A[px + i] + cval
+    for i in tl.Parallel(block_N):
+        if px + i < N:
+            C[px + i] = A[px + i] + cval
📝 Committable suggestion
" with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx:\n", | |
" px = bx * block_N\n", | |
" for i in tl.Parallel(block_N):\n", | |
" C[px + i] = A[px + i] + cval\n", | |
" return C" | |
with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx: | |
px = bx * block_N | |
for i in tl.Parallel(block_N): | |
if px + i < N: | |
C[px + i] = A[px + i] + cval | |
return C |
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 297-301, the vec_add_scalar kernel
writes past the array end for the tail block; guard the tail by computing the
global index (idx = px + i) and only performing the write when idx < N (e.g., in
the loop add an if idx < N: C[idx] = A[idx] + cval or use a masked store),
ensuring out-of-bounds accesses are prevented.
"torch_beg = time.perf_counter()\n", | ||
"for _ in range(10000):\n", | ||
" A @ B\n", | ||
"torch_end = time.perf_counter()\n", | ||
"elapsed = (torch_end - torch_beg) / 10000 * 1e6\n", | ||
"\n", | ||
"print('Torch time: ', elapsed, 'us')\n", | ||
"\n", | ||
"tl_beg = time.perf_counter()\n", | ||
"for _ in range(10000):\n", | ||
" gemm(A, B)\n", | ||
"tl_end = time.perf_counter()\n", | ||
"elapsed = (tl_end - tl_beg) / 10000 * 1e6\n", |
Synchronize CUDA for accurate timing.
GPU launches are async; without synchronize(), the timings are unreliable.
- torch_beg = time.perf_counter()
- for _ in range(10000):
- A @ B
- torch_end = time.perf_counter()
+ torch.cuda.synchronize()
+ torch_beg = time.perf_counter()
+ for _ in range(10000):
+ A @ B
+ torch.cuda.synchronize()
+ torch_end = time.perf_counter()
@@
- tl_beg = time.perf_counter()
- for _ in range(10000):
- gemm(A, B)
- tl_end = time.perf_counter()
+ torch.cuda.synchronize()
+ tl_beg = time.perf_counter()
+ for _ in range(10000):
+ gemm(A, B)
+ torch.cuda.synchronize()
+ tl_end = time.perf_counter()
📝 Committable suggestion
"torch_beg = time.perf_counter()\n", | |
"for _ in range(10000):\n", | |
" A @ B\n", | |
"torch_end = time.perf_counter()\n", | |
"elapsed = (torch_end - torch_beg) / 10000 * 1e6\n", | |
"\n", | |
"print('Torch time: ', elapsed, 'us')\n", | |
"\n", | |
"tl_beg = time.perf_counter()\n", | |
"for _ in range(10000):\n", | |
" gemm(A, B)\n", | |
"tl_end = time.perf_counter()\n", | |
"elapsed = (tl_end - tl_beg) / 10000 * 1e6\n", | |
# Ensure any prior GPU work is finished | |
torch.cuda.synchronize() | |
torch_beg = time.perf_counter() | |
for _ in range(10000): | |
A @ B | |
# Wait for the last A @ B to complete | |
torch.cuda.synchronize() | |
torch_end = time.perf_counter() | |
elapsed = (torch_end - torch_beg) / 10000 * 1e6 | |
print('Torch time: ', elapsed, 'us') | |
# Repeat for the Triton gemm | |
torch.cuda.synchronize() | |
tl_beg = time.perf_counter() | |
for _ in range(10000): | |
gemm(A, B) | |
# Wait for the last gemm to complete | |
torch.cuda.synchronize() | |
tl_end = time.perf_counter() | |
elapsed = (tl_end - tl_beg) / 10000 * 1e6 | |
print('Triton time: ', elapsed, 'us') |
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 415-427, the timing around the GPU
matrix multiplies is unreliable because CUDA launches are asynchronous; add
explicit synchronization calls so the measured intervals include kernel
execution. Specifically, call torch.cuda.synchronize() (guarded by
torch.cuda.is_available()) immediately before starting each perf_counter() and
immediately after each timed loop (both for the A @ B loop and the gemm(A, B)
loop) so the timers bracket actual GPU work.
"torch_beg = time.perf_counter()\n", | ||
"for _ in range(10000):\n", | ||
" A + B\n", | ||
"torch_end = time.perf_counter()\n", | ||
"elapsed = (torch_end - torch_beg) / 10000 * 1e6\n", | ||
"\n", | ||
"print('Torch time: ', elapsed, 'us')\n", | ||
"\n", | ||
"tl_beg = time.perf_counter()\n", | ||
"for _ in range(10000):\n", | ||
" vec_add(A, B)\n", | ||
"tl_end = time.perf_counter()\n", | ||
"elapsed = (tl_end - tl_beg) / 10000 * 1e6\n", |
Synchronize CUDA for accurate timing (Vec Add).
Same issue as GEMM benchmark.
- torch_beg = time.perf_counter()
- for _ in range(10000):
- A + B
- torch_end = time.perf_counter()
+ torch.cuda.synchronize()
+ torch_beg = time.perf_counter()
+ for _ in range(10000):
+ A + B
+ torch.cuda.synchronize()
+ torch_end = time.perf_counter()
@@
- tl_beg = time.perf_counter()
- for _ in range(10000):
- vec_add(A, B)
- tl_end = time.perf_counter()
+ torch.cuda.synchronize()
+ tl_beg = time.perf_counter()
+ for _ in range(10000):
+ vec_add(A, B)
+ torch.cuda.synchronize()
+ tl_end = time.perf_counter()
📝 Committable suggestion
"torch_beg = time.perf_counter()\n", | |
"for _ in range(10000):\n", | |
" A + B\n", | |
"torch_end = time.perf_counter()\n", | |
"elapsed = (torch_end - torch_beg) / 10000 * 1e6\n", | |
"\n", | |
"print('Torch time: ', elapsed, 'us')\n", | |
"\n", | |
"tl_beg = time.perf_counter()\n", | |
"for _ in range(10000):\n", | |
" vec_add(A, B)\n", | |
"tl_end = time.perf_counter()\n", | |
"elapsed = (tl_end - tl_beg) / 10000 * 1e6\n", | |
torch.cuda.synchronize() | |
torch_beg = time.perf_counter() | |
for _ in range(10000): | |
A + B | |
torch.cuda.synchronize() | |
torch_end = time.perf_counter() | |
elapsed = (torch_end - torch_beg) / 10000 * 1e6 | |
print('Torch time: ', elapsed, 'us') | |
torch.cuda.synchronize() | |
tl_beg = time.perf_counter() | |
for _ in range(10000): | |
vec_add(A, B) | |
torch.cuda.synchronize() | |
tl_end = time.perf_counter() | |
elapsed = (tl_end - tl_beg) / 10000 * 1e6 |
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 455 to 467, the Vec Add timing
measures asynchronous CUDA launches leading to incorrect timings; synchronize
the CUDA device before starting each timer and again after the loop (e.g.,
torch.cuda.synchronize()) so all GPU work is completed prior to reading the
clock, and apply the same synchronization pattern used in the GEMM benchmark to
both the PyTorch A+B loop and the vec_add(A, B) loop.
" (M, K), (N, K2) = A.shape, B.shape\n", | ||
" assert K == K2, \"Expect matrix A and B to have the same number of columns\"\n", | ||
" C = tl.empty((M, N), dtype=accum_dtype)\n", |
Fix shape destructuring in gemm_tune_advanced (B dims order).
B is (K, N); the current destructuring makes the assert compare K with N.
- (M, K), (N, K2) = A.shape, B.shape
- assert K == K2, "Expect matrix A and B to have the same number of columns"
+ M, K = A.shape
+ K2, N = B.shape
+ assert K == K2, "Expect A.shape[1] == B.shape[0]"
📝 Committable suggestion
" (M, K), (N, K2) = A.shape, B.shape\n", | |
" assert K == K2, \"Expect matrix A and B to have the same number of columns\"\n", | |
" C = tl.empty((M, N), dtype=accum_dtype)\n", | |
M, K = A.shape | |
K2, N = B.shape | |
assert K == K2, "Expect A.shape[1] == B.shape[0]" | |
C = tl.empty((M, N), dtype=accum_dtype) |
🤖 Prompt for AI Agents
In examples/jitv2/jitv2.ipynb around lines 950 to 952, the code destructures
B.shape as (N, K2) but B is actually (K, N); change the destructuring to "(M,
K), (K2, N) = A.shape, B.shape" and keep/ensure the assert as "assert K == K2,
..." so the shared dimension is compared correctly; no change needed to the
creation of C which should remain tl.empty((M, N), dtype=accum_dtype).
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
source = "" | ||
source += "def parse_args(" + ", ".join(closure.keys()) + "):\n" | ||
source += f" def {fn_name}(" + ", ".join(func_args) + ", __stream__=None" + "):\n" | ||
source += " " + "\n ".join(code_parse_arg) + "\n" | ||
source += " __const_args__ = (" + ", ".join(tup_const) + ")\n" | ||
source += " __dyn_args__ = (" + ", ".join(tup_dyn) + ", __stream__, __device__)\n" | ||
source += " return __const_args__, __dyn_args__\n" |
Handle zero dynamic arguments when generating the arg parser

When generate_arg_parser builds the source string, it always emits __dyn_args__ = (…, __stream__, __device__) and relies on ", ".join(tup_dyn) to insert the leading values. If a kernel has no dynamic parameters (no tensors or DynSchema arguments), tup_dyn is empty and the emitted code becomes __dyn_args__ = (, __stream__, __device__), which is invalid Python and raises a SyntaxError the first time such a kernel is decorated. Consider appending __stream__ and __device__ to the list before joining, or adding a conditional to avoid producing an empty tuple prefix.
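A minimal sketch of one possible fix, reusing the names from the quoted snippet (the surrounding builder code is assumed):

# Append the always-present trailing values before joining, so an empty
# tup_dyn still yields valid Python: "(__stream__, __device__)" rather
# than "(, __stream__, __device__)".
dyn_parts = list(tup_dyn) + ["__stream__", "__device__"]
source += " __dyn_args__ = (" + ", ".join(dyn_parts) + ")\n"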
# Initial function call and synchronization
def fn():
    if timeout is not None:
        run_with_timeout(bench_fn, timeout)
    else:
        bench_fn()

fn()
torch.accelerator.synchronize()
Replace nonexistent torch.accelerator.synchronize call

The new benchmarking routine now calls torch.accelerator.synchronize() right after the warm-up call. PyTorch does not expose a torch.accelerator module, so do_bench will raise AttributeError before any timing runs (even on CUDA, where the previous code used torch.cuda.synchronize()). This regression makes the default benchmarking utility unusable until the correct backend-specific synchronize (e.g., torch.cuda.synchronize() or an explicit branch for MPS) is restored.
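A hedged sketch of a backend-tolerant replacement, using only public PyTorch APIs (torch.accelerator is assumed to exist on newer releases only):

import torch

def _synchronize() -> None:
    # Prefer the unified accelerator API where present (newer PyTorch),
    # then fall back to per-backend synchronization.
    if hasattr(torch, "accelerator") and hasattr(torch.accelerator, "synchronize"):
        torch.accelerator.synchronize()
    elif torch.cuda.is_available():
        torch.cuda.synchronize()
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        torch.mps.synchronize()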
Tilelang JITv2
In this PR we introduce Tilelang JITv2, a new frontend for Tilelang with modern and attractive features.
Features
Kernel Declaration
Function declaration has been simplified:
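For illustration, a hedged sketch in the style of the notebook examples reviewed above (the decorator name and import path are assumptions, not the confirmed JITv2 surface):

import tilelang as tl  # assumed v2-enabled namespace

@tl.jit
def vec_add(A: tl.Tensor, B: tl.Tensor, block_N: int = 256):
    N = A.shape[0]
    C = tl.empty((N,), dtype=A.dtype)
    with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx:
        px = bx * block_N
        for i in tl.Parallel(block_N):
            if px + i < N:
                C[px + i] = A[px + i] + B[px + i]
    return C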
Kernel Call
When calling functions, tensor shapes, strides, and dtypes are automatically inferred:
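A usage sketch (hypothetical sizes; nothing about shapes, strides, or dtypes is passed explicitly):

import torch

A = torch.randn(1 << 20, device="cuda", dtype=torch.float16)
B = torch.randn(1 << 20, device="cuda", dtype=torch.float16)
C = vec_add(A, B)  # shape, stride, and dtype are inferred from A and B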
Auto Tuning
Auto tuning can be done via default arguments:
Or on-the-fly:
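A hedged sketch of both styles; tl.autotune is an assumed spelling for the candidate-set marker, and the parameter names follow the GEMM examples on this page:

@tl.jit
def gemm(A: tl.Tensor, B: tl.Tensor,
         block_M: int = tl.autotune([64, 128]),   # tuned via default arguments
         block_N: int = tl.autotune([64, 128])):
    ...

gemm(A, B)                           # the autotuner picks block_M/block_N
gemm(A, B, block_M=128, block_N=64)  # or pin them on-the-fly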
Smarter Static Evaluation
JITv2 preserves as much Python code as possible, allowing calls to custom Python functions or conditional kernel generation:
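A sketch of conditional generation (assuming a plain str argument is classified as const, consistent with the Static & Dynamic Arguments section below). The Python-level branch on op is resolved while the kernel is built, so only one body is emitted:

@tl.jit
def vec_op(A: tl.Tensor, B: tl.Tensor, op: str = "add", block_N: int = 256):
    N = A.shape[0]
    C = tl.empty((N,), dtype=A.dtype)
    with tl.Kernel(tl.ceildiv(N, block_N), threads=128) as bx:
        px = bx * block_N
        for i in tl.Parallel(block_N):
            if px + i < N:
                # `op` is ordinary Python here, so this branch selects which
                # code is generated rather than executing on the GPU.
                if op == "add":
                    C[px + i] = A[px + i] + B[px + i]
                else:
                    C[px + i] = A[px + i] * B[px + i]
    return C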
Smarter Type Hinting
JITv2 not only eliminates annoying type warnings, but also adds extensive type annotations. This helps you clearly see each Tensor’s dimensions, and marks whether a value is on the Python side or kernel side. Even generated functions and the JIT-compiled kernels have friendly type hints:
Extremely Low Overhead
JITv2's Python overhead has been driven down aggressively. In the fast path, only dynamic parameters are checked, bringing call overhead in line with calling a torch function (e.g., torch.add):

Architecture
The Tilelang JIT workflow:
Static & Dynamic Arguments
JITv2 inspects function signatures to determine which parameters are const and which are dyn (see the sketch after this list):
- Scalars annotated as int, float, or ptr are treated as tir.Var
- Tuple is preferred over List
- Tensor must be explicitly annotated, because its data_ptr is always dynamic
- Annotations are not validated against the values actually passed (e.g., annotating int but passing a dict); note: validation is hard (like writing a pydantic)
JITv2 generates Python code for the fast path, which unpacks const and dyn arguments and then invokes the kernel:
torch.to_dlpack
K
. More complex asserts may be compiled to host code (not yet supported)Memory Allocation & Return Values
Inside functions, use
T.alloc_global
to create global buffers:T.alloc_global
is friendlier for type linting — it’s translated intotorch.empty
T.alloc_xxx
must be assigned to a variable (x = T.alloc_xxx()
), not passed directly as a function parameter (e.g.,foo(T.alloc_shared(...))
is not allowed)BLOCK_M + BLOCK_N
is not allowed):TODOs
TODOs
- tl.language
Summary by CodeRabbit
New Features
Improvements
Profiler