02. Usage Guide
This guide walks you through the complete workflow of using TritonParse to analyze Triton kernel compilation processes.
The TritonParse workflow consists of three main steps:
- Generate Traces - Capture Triton compilation events
- Parse Traces - Process raw logs into structured format
- Analyze Results - Visualize and explore using the web interface
All TritonParse workflows follow this pattern:
import tritonparse.structured_logging
log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_trace_launch=True)
import tritonparse.utils
tritonparse.utils.unified_parse(
source=log_path,
out="./parsed_output",
overwrite=True
)
Alternative - Command Line:
tritonparse parse ./logs/ --out ./parsed_output
Here's a complete example showing how to trace a Triton kernel:
import torch
import triton
import triton.language as tl
import tritonparse.structured_logging
import tritonparse.utils
# Initialize logging (see Standard Setup Pattern above)
log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_trace_launch=True)
@triton.jit
def add_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    c = a + b
    tl.store(c_ptr + offsets, c, mask=mask)

def tensor_add(a, b):
    n_elements = a.numel()
    c = torch.empty_like(a)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    add_kernel[grid](a, b, c, n_elements, BLOCK_SIZE)
    return c

# Example usage
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(1024, 1024, device=device, dtype=torch.float32)
    b = torch.randn(1024, 1024, device=device, dtype=torch.float32)

    # Execute kernel (this will be traced)
    c = tensor_add(a, b)

    # Parse the generated logs (see Standard Setup Pattern above)
    tritonparse.utils.unified_parse(source=log_path, out="./parsed_output", overwrite=True)
💡 Tip: See `tests/test_add.py` in the repository for a complete runnable example.
For PyTorch 2.0+ with `torch.compile`:
import torch
import tritonparse.structured_logging
import tritonparse.utils
# Initialize logging
log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_trace_launch=True)
def simple_add(a, b):
    return a + b
# Compile the function
compiled_add = torch.compile(simple_add)
# Execute (this will be traced)
device = "cuda"
a = torch.randn(1024, 1024, device=device, dtype=torch.float32)
b = torch.randn(1024, 1024, device=device, dtype=torch.float32)
result = compiled_add(a, b)
# Parse logs
tritonparse.utils.unified_parse(source=log_path, out="./parsed_output", overwrite=True)
💡 Note: Set `TORCHINDUCTOR_FX_GRAPH_CACHE=0` to ensure compilation happens on every run during testing.
Configure TritonParse behavior with these environment variables:
| Variable | Description | Example |
|---|---|---|
| `TRITON_TRACE_FOLDER` | Trace output directory | `"./logs/"` |
| `TRITON_TRACE_LAUNCH` | Enable launch tracing (`"1"` or `"0"`) | `"1"` |
| `TRITONPARSE_DEBUG` | Enable debug logging | `"1"` |
| `TRITON_TRACE_GZIP` | Enable gzip compression when logging | `"1"` |
| `TRITONPARSE_KERNEL_ALLOWLIST` | Filter specific kernels (comma-separated patterns) | `"my_kernel*,important_*"` |
| `TORCHINDUCTOR_FX_GRAPH_CACHE` | Disable FX graph cache (for testing) | `"0"` |
Usage:
export TRITON_TRACE_FOLDER="./logs/"
export TRITON_TRACE_LAUNCH="1"
export TORCHINDUCTOR_FX_GRAPH_CACHE=0
python your_script.py
# Run with environment variables
TORCHINDUCTOR_FX_GRAPH_CACHE=0 python your_script.py
Expected Output:
Triton kernel executed successfully
Torch compiled function executed successfully
tritonparse log file list: /tmp/tmp1gan7zky/log_file_list.json
INFO:tritonparse:Copying parsed logs from /tmp/tmp1gan7zky to /scratch/findhao/tritonparse/tests/parsed_output
================================================================================
TRITONPARSE PARSING RESULTS
================================================================================
Parsed files directory: /scratch/findhao/tritonparse/tests/parsed_output
Total files generated: 2
Generated files:
--------------------------------------------------
1. dedicated_log_triton_trace_findhao__mapped.ndjson.gz (7.2KB)
2. log_file_list.json (181B)
================================================================================
Parsing completed successfully!
================================================================================
import tritonparse.utils
# Basic parsing
tritonparse.utils.unified_parse(
source="./logs/", # Input directory with raw logs
out="./parsed_output", # Output directory for processed files
overwrite=True # Overwrite existing output
)
# Advanced options
tritonparse.utils.unified_parse(
source="./logs/",
out="./parsed_output",
overwrite=True,
rank=0, # Analyze specific rank (for multi-GPU)
all_ranks=False, # Or analyze all ranks
verbose=True # Enable verbose logging
)
# Basic usage
tritonparse parse ./logs/ --out ./parsed_output
# Alternative: using python -m
python -m tritonparse parse ./logs/ --out ./parsed_output
# With options
tritonparse parse ./logs/ --out ./parsed_output --overwrite --verbose
# Multi-GPU: parse specific rank
tritonparse parse ./logs/ --out ./parsed_output --rank 0
# Multi-GPU: parse all ranks
tritonparse parse ./logs/ --out ./parsed_output --all-ranks
- Visit the live tool: https://meta-pytorch.org/tritonparse/
- Load your trace files:
  - Click "Browse Files" or drag-and-drop
  - Select `.gz` files from your `parsed_output` directory
  - Or select `.ndjson` files from your `logs` directory
- Explore the visualization:
  - Kernel Overview Tab: Kernel metadata, call stack, IR links
  - IR Code View Tab: Side-by-side IR viewing with line mapping
For contributors or custom deployments:
cd website
npm install
npm run dev
Access at http://localhost:5173
| Format | Description | Source Mapping | Recommended |
|---|---|---|---|
| `.gz` | Compressed parsed traces | ✅ Yes | ✅ Yes |
| `.ndjson` | Raw trace logs | ❌ No | ❌ No |
Note: `.ndjson` files don't contain source code mappings between IR stages or launch diffs. Always use `.gz` files for full functionality.
The overview page shows:
- Kernel Information: Name, hash, grid/block sizes
- Compilation Metadata: Device, compile time, memory usage
- Call Stack: Python source code that triggered compilation
- IR Navigation: Links to different IR representations
- Launch Diff: Launch parameters that changed across different launches of the same kernel
The IR code view offers:
- Side-by-side IR viewing: Compare different compilation stages
- Synchronized highlighting: Click a line to see corresponding lines in other IRs
- Source mapping: Trace transformations across compilation pipeline
Compare kernels from two different trace files side-by-side:
- Cross-trace comparison: Validate optimizations, track kernel evolution, debug differences
- Flexible modes: Single IR focus or all IRs simultaneously
- Customizable diff: Ignore whitespace, word/line-level, context control
- URL shareable: `?view=file_diff&json_url=trace1.gz&json_b_url=trace2.gz`
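As a sketch, a shareable comparison link might look like the following; the example.com paths are placeholders, and this assumes both parsed trace files are reachable from the browser:

https://meta-pytorch.org/tritonparse/?view=file_diff&json_url=https://example.com/trace1.gz&json_b_url=https://example.com/trace2.gz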
💡 Tip: See the Web Interface Guide for detailed File Diff documentation.
| Stage | Description | When Generated |
|---|---|---|
| TTIR | Triton IR - Language-level operations | After parsing |
| TTGIR | Triton GPU IR - GPU-specific operations | After Triton frontend |
| LLIR | LLVM IR - Low-level operations | After LLVM conversion |
| PTX | NVIDIA PTX Assembly | For NVIDIA GPUs |
| AMDGCN | AMD GPU Assembly | For AMD GPUs |
TritonParse can analyze kernel launch parameters to identify variations and commonalities across different launches of the same kernel. This is useful for understanding how dynamic shapes or other factors affect kernel execution.
- Enable Launch Tracing: You must enable launch tracing during the trace generation step by passing `enable_trace_launch=True` to `tritonparse.structured_logging.init()`.
- Parsing: During the parsing step (`tritonparse.utils.unified_parse`), TritonParse automatically groups all launches for each kernel.
- Launch Diff Event: A new event of type `launch_diff` is generated for each kernel. This event contains:
  - `total_launches`: The total number of times the kernel was launched.
  - `diffs`: A dictionary showing which launch parameters (e.g., `grid_x`, `grid_y`) changed across launches and what their different values were.
  - `sames`: A dictionary showing which launch parameters remained constant across all launches.
  - `launch_index_map`: A mapping from the launch index to the original line number in the trace file.
{
  "event_type": "launch_diff",
  "hash": "...",
  "name": "triton_kernel_name",
  "total_launches": 10,
  "launch_index_map": { "0": 15, "1": 25, ... },
  "diffs": {
    "grid_x": [1024, 2048]
  },
  "sames": {
    "grid_y": 1,
    "grid_z": 1,
    "stream": 7
  }
}
This example shows that `grid_x` varied between `1024` and `2048` across 10 launches, while the other parameters remained the same.
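If you want to inspect these events outside the web interface, a minimal sketch like the one below can filter them out of a parsed trace. It assumes the parsed file is gzipped NDJSON with one JSON event per line; the file name is taken from the example output above and will differ on your machine:

import gzip
import json

path = "./parsed_output/dedicated_log_triton_trace_findhao__mapped.ndjson.gz"  # replace with your parsed file

with gzip.open(path, "rt") as f:
    for line in f:
        event = json.loads(line)
        # Keep only the per-kernel launch summary events
        if event.get("event_type") == "launch_diff":
            print(event["name"], event["total_launches"], event.get("diffs"))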
TritonParse can automatically generate standalone Python scripts that reproduce specific kernel executions. Useful for debugging, sharing test cases, and isolating performance issues.
Command Line:
# Generate reproducer for first launch event
tritonparse reproduce ./parsed_output/trace.ndjson --line 1 --out-dir repro_output
# Using compressed files
tritonparse reproduce ./parsed_output/trace.ndjson.gz --line 5 --out-dir my_repro
# With custom template
tritonparse reproduce trace.ndjson --line 1 --template /path/to/my_template.py
Python API:
from tritonparse.reproducer.orchestrator import reproduce
result = reproduce(
input_path="./parsed_output/trace.ndjson",
line_index=1, # Which launch event (1-based)
out_dir="repro_output",
template="example" # Built-in template
)
print(f"Script: {result['repo_script']}")
print(f"Context: {result['repo_context']}")
repro_output/<kernel_name>/
├── repro_<timestamp>.py              # Standalone executable script
├── repro_context_<timestamp>.json    # Kernel metadata and parameters
└── <hash>.bin                        # Tensor blobs (if enabled during tracing)
| Parameter | Description | Default |
|---|---|---|
| `input` | Path to NDJSON trace file (`.ndjson` or `.ndjson.gz`) | Required |
| `--line` | Line number (1-based) of launch event | `1` |
| `--out-dir` | Output directory | `repro_output/<kernel>/` |
| `--template` | Template name or path | `example` |
The reproducer reconstructs tensors using one of these strategies:
1. Blob Files (Highest Fidelity)
# Enable during tracing
tritonparse.structured_logging.init(
"./logs/",
enable_trace_launch=True,
save_tensor_blobs=True # Save actual tensor data
)
2. Statistical Reconstruction (Good Approximation)
- Uses saved mean, std, min, max to generate similar data
- Matches shape, dtype, device of original
3. Random Data (Fallback)
- Random generation matching only shape and dtype
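As an illustration of the second strategy (not TritonParse's actual implementation), a tensor with similar statistics could be rebuilt roughly as below; the function name and the statistic values are made up for the example:

import torch

def reconstruct_from_stats(shape, dtype, device, mean, std, min_val, max_val):
    # Draw normally distributed data with the recorded mean/std, then clamp to the recorded range
    t = torch.randn(shape, device=device) * std + mean
    return t.clamp(min_val, max_val).to(dtype)

# Hypothetical statistics captured during tracing
x = reconstruct_from_stats((1024, 1024), torch.float16, "cpu", mean=0.0, std=1.0, min_val=-3.0, max_val=3.0)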
Bug Isolation:
tritonparse reproduce trace.ndjson --line 42 --out-dir bug_repro
cd bug_repro && python repro_*.py
Performance Benchmarking:
tritonparse reproduce trace.ndjson --line 1 --out-dir benchmark
# Modify script to add timing
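For the timing step, one sketch is to wrap the reproduced launch in a function and measure it; `run_kernel` below is a hypothetical placeholder for the launch code copied from the generated `repro_*.py` script:

import time
import torch

def run_kernel():
    # Paste the reproduced kernel launch from repro_<timestamp>.py here (placeholder)
    pass

run_kernel()  # warm-up so compilation time is excluded
if torch.cuda.is_available():
    torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(100):
    run_kernel()
if torch.cuda.is_available():
    torch.cuda.synchronize()
print(f"Average time: {(time.perf_counter() - start) / 100 * 1e3:.3f} ms")

On a GPU, `triton.testing.do_bench(run_kernel)` is a more idiomatic alternative.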
Kernel Comparison:
tritonparse reproduce trace_v1.ndjson --line 1 --out-dir v1
tritonparse reproduce trace_v2.ndjson --line 1 --out-dir v2
# Compare outputs and performance
Create your own template:
# my_template.py
"""Custom reproducer template"""
import torch

# {{KERNEL_IMPORT_PLACEHOLDER}}

if __name__ == "__main__":
    # {{KERNEL_INVOCATION_PLACEHOLDER}}
    print("Custom execution complete!")
Available Placeholders:
- `{{KERNEL_IMPORT_PLACEHOLDER}}` - Kernel imports
- `{{KERNEL_INVOCATION_PLACEHOLDER}}` - Launch code
- `{{KERNEL_SYSPATH_PLACEHOLDER}}` - System path setup
- `{{JSON_FILE_NAME_PLACEHOLDER}}` - Context JSON filename
Usage:
tritonparse reproduce trace.ndjson --line 1 --template /path/to/my_template.py
For `triton_kernels` projects:
from triton_kernels.tensor import Tensor, Storage, StridedLayout
# Reproducer automatically handles these if triton_kernels is installed
If not installed:
RuntimeError: Optional dependency 'triton_kernels.tensor' is not installed
Solution: pip install triton_kernels
TritonParse supports two initialization methods:
tritonparse.structured_logging.init(
trace_folder="./logs/",
enable_trace_launch=True
)
import os
os.environ["TRITON_TRACE_FOLDER"] = "./logs/"
os.environ["TRITON_TRACE_LAUNCH"] = "1"
tritonparse.structured_logging.init_with_env()
Or from shell:
export TRITON_TRACE_FOLDER="./logs/"
export TRITON_TRACE_LAUNCH="1"
python my_script.py
| Method | When to Use | Pros | Cons |
|---|---|---|---|
| `init(path, enable_trace_launch=True)` | Direct control in code | Explicit, type-safe | Hardcoded paths |
| `init_with_env()` | Environment-based config | Flexible, CI/CD friendly | Requires env setup |
Trace only specific kernels:
export TRITONPARSE_KERNEL_ALLOWLIST="my_kernel*,important_*"
python your_script.py
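The same filter can also be set from Python, as a sketch, provided the variable is exported before logging is initialized:

import os
os.environ["TRITONPARSE_KERNEL_ALLOWLIST"] = "my_kernel*,important_*"  # comma-separated patterns from the table above

import tritonparse.structured_logging
tritonparse.structured_logging.init("./logs/", enable_trace_launch=True)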
Parse all ranks:
tritonparse.utils.unified_parse(
source="./logs/",
out="./parsed_output",
all_ranks=True
)
Parse specific rank:
tritonparse.utils.unified_parse(
source="./logs/",
out="./parsed_output",
rank=1
)
Command line:
tritonparse parse ./logs/ --out ./parsed_output --all-ranks
tritonparse parse ./logs/ --out ./parsed_output --rank 1
- Floating point: `float32`, `float16`, `bfloat16`, `float8_e4m3fn`, `float8_e5m2`
- Integers: `int8`, `int16`, `int32`, `int64`, `uint8`
- Complex: `complex64`, `complex128`
- Boolean: `bool`
- Custom: `triton_kernels.tensor.Tensor` (requires `triton_kernels`)
| Format | Source Mapping | Launch Diffs | Recommended |
|---|---|---|---|
| `.ndjson.gz` | ✅ Yes | ✅ Yes | ✅ Yes |
| `.ndjson` | ❌ No | ❌ No | ❌ No |
Always use `.gz` files from `parsed_output/` for full functionality.
| Stage | Description | When Generated |
|---|---|---|
| TTIR | Triton IR - Language-level operations | After parsing |
| TTGIR | Triton GPU IR - GPU-specific operations | After Triton frontend |
| LLIR | LLVM IR - Low-level operations | After LLVM conversion |
| PTX | NVIDIA PTX Assembly | For NVIDIA GPUs |
| AMDGCN | AMD GPU Assembly | For AMD GPUs |