
02. Usage Guide


This guide walks you through the complete workflow of using TritonParse to analyze Triton kernel compilation processes.

πŸ“‹ Overview

The TritonParse workflow consists of three main steps:

  1. Generate Traces - Capture Triton compilation events
  2. Parse Traces - Process raw logs into structured format
  3. Analyze Results - Visualize and explore using the web interface

πŸš€ Standard Setup Pattern

All TritonParse workflows follow this pattern:

Initialize Logging

import tritonparse.structured_logging

log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_trace_launch=True)

Parse Traces

import tritonparse.utils

tritonparse.utils.unified_parse(
    source=log_path,
    out="./parsed_output",
    overwrite=True
)

Alternative - Command Line:

tritonparse parse ./logs/ --out ./parsed_output

πŸš€ Step 1: Generate Triton Trace Files

Example: Complete Triton Kernel

Here's a complete example showing how to trace a Triton kernel:

import torch
import triton
import triton.language as tl
import tritonparse.structured_logging
import tritonparse.utils

# Initialize logging (see Standard Setup Pattern above)
log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_trace_launch=True)

@triton.jit
def add_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    c = a + b
    tl.store(c_ptr + offsets, c, mask=mask)

def tensor_add(a, b):
    n_elements = a.numel()
    c = torch.empty_like(a)
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    add_kernel[grid](a, b, c, n_elements, BLOCK_SIZE)
    return c

# Example usage
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(1024, 1024, device=device, dtype=torch.float32)
    b = torch.randn(1024, 1024, device=device, dtype=torch.float32)
    
    # Execute kernel (this will be traced)
    c = tensor_add(a, b)
    
    # Parse the generated logs (see Standard Setup Pattern above)
    tritonparse.utils.unified_parse(source=log_path, out="./parsed_output", overwrite=True)

πŸ’‘ Tip: See tests/test_add.py in the repository for a complete runnable example.

PyTorch 2.0+ Compiled Functions

For PyTorch 2.0+ with torch.compile:

import torch
import tritonparse.structured_logging
import tritonparse.utils

# Initialize logging
log_path = "./logs/"
tritonparse.structured_logging.init(log_path, enable_trace_launch=True)

def simple_add(a, b):
    return a + b

# Compile the function
compiled_add = torch.compile(simple_add)

# Execute (this will be traced)
device = "cuda"
a = torch.randn(1024, 1024, device=device, dtype=torch.float32)
b = torch.randn(1024, 1024, device=device, dtype=torch.float32)
result = compiled_add(a, b)

# Parse logs
tritonparse.utils.unified_parse(source=log_path, out="./parsed_output", overwrite=True)

πŸ’‘ Note: Set TORCHINDUCTOR_FX_GRAPH_CACHE=0 to ensure compilation happens every run during testing.

Environment Variables

Configure TritonParse behavior with these environment variables:

| Variable | Description | Example |
|---|---|---|
| TRITON_TRACE_FOLDER | Trace output directory | "./logs/" |
| TRITON_TRACE_LAUNCH | Enable launch tracing ("1" or "0") | "1" |
| TRITONPARSE_DEBUG | Enable debug logging | "1" |
| TRITON_TRACE_GZIP | Enable gzip compression for trace logs | "1" |
| TRITONPARSE_KERNEL_ALLOWLIST | Filter specific kernels (comma-separated patterns) | "my_kernel*,important_*" |
| TORCHINDUCTOR_FX_GRAPH_CACHE | Disable the FX graph cache when set to "0" (for testing) | "0" |

Usage:

export TRITON_TRACE_FOLDER="./logs/"
export TRITON_TRACE_LAUNCH="1"
export TORCHINDUCTOR_FX_GRAPH_CACHE=0

python your_script.py

Running Your Code

# Run with environment variables
TORCHINDUCTOR_FX_GRAPH_CACHE=0 python your_script.py

Expected Output:

Triton kernel executed successfully
Torch compiled function executed successfully
tritonparse log file list: /tmp/tmp1gan7zky/log_file_list.json
INFO:tritonparse:Copying parsed logs from /tmp/tmp1gan7zky to /scratch/findhao/tritonparse/tests/parsed_output

================================================================================
πŸ“ TRITONPARSE PARSING RESULTS
================================================================================
πŸ“‚ Parsed files directory: /scratch/findhao/tritonparse/tests/parsed_output
πŸ“Š Total files generated: 2

πŸ“„ Generated files:
--------------------------------------------------
   1. πŸ“ dedicated_log_triton_trace_findhao__mapped.ndjson.gz (7.2KB)
   2. πŸ“ log_file_list.json (181B)
================================================================================
βœ… Parsing completed successfully!
================================================================================

πŸ”§ Step 2: Parse Trace Files

Python API

import tritonparse.utils

# Basic parsing
tritonparse.utils.unified_parse(
    source="./logs/",           # Input directory with raw logs
    out="./parsed_output",      # Output directory for processed files
    overwrite=True              # Overwrite existing output
)

# Advanced options
tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    overwrite=True,
    rank=0,                     # Analyze specific rank (for multi-GPU)
    all_ranks=False,            # Or analyze all ranks
    verbose=True                # Enable verbose logging
)

Command Line Interface

# Basic usage
tritonparse parse ./logs/ --out ./parsed_output

# Alternative: using python -m
python -m tritonparse parse ./logs/ --out ./parsed_output

# With options
tritonparse parse ./logs/ --out ./parsed_output --overwrite --verbose

# Multi-GPU: parse specific rank
tritonparse parse ./logs/ --out ./parsed_output --rank 0

# Multi-GPU: parse all ranks
tritonparse parse ./logs/ --out ./parsed_output --all-ranks

🌐 Step 3: Analyze with Web Interface

Option A: Online Interface (Recommended)

  1. Visit the live tool: https://meta-pytorch.org/tritonparse/

  2. Load your trace files:

    • Click "Browse Files" or drag-and-drop
    • Select .gz files from your parsed_output directory
    • Or select .ndjson files from your logs directory
  3. Explore the visualization:

    • Kernel Overview Tab: Kernel metadata, call stack, IR links
    • IR Code View Tab: Side-by-side IR viewing with line mapping

Option B: Local Development Interface

For contributors or custom deployments:

cd website
npm install
npm run dev

Access at http://localhost:5173

Supported File Formats

| Format | Description | Source Mapping | Recommended |
|---|---|---|---|
| .gz | Compressed parsed traces | βœ… Yes | βœ… Yes |
| .ndjson | Raw trace logs | ❌ No | ⚠️ Basic use only |

Note: raw .ndjson files contain neither source code mappings between IR stages nor launch diff events. Always use .gz files for full functionality.
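If you want to inspect a parsed trace outside the web interface, the .ndjson.gz output is just gzip-compressed newline-delimited JSON, so it can be read with the standard library. A minimal sketch (the file name is a placeholder; the event_type key follows the launch_diff example later in this guide, so check your own trace for the exact keys):

import gzip
import json

# Each line of a parsed .ndjson.gz file is one JSON event emitted by TritonParse.
trace_path = "./parsed_output/trace.ndjson.gz"  # placeholder path

with gzip.open(trace_path, "rt") as f:
    for line in f:
        event = json.loads(line)
        # Print the event type of every record (key name assumed; inspect your file to confirm).
        print(event.get("event_type"))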

πŸ“Š Understanding the Results

Kernel Overview

The overview page shows:

  • Kernel Information: Name, hash, grid/block sizes
  • Compilation Metadata: Device, compile time, memory usage
  • Call Stack: Python source code that triggered compilation
  • IR Navigation: Links to different IR representations
  • Launch Diff: Launch parameters that changed across different launches of the same kernel

IR Code View

The IR code view offers:

  • Side-by-side IR viewing: Compare different compilation stages
  • Synchronized highlighting: Click a line to see corresponding lines in other IRs
  • Source mapping: Trace transformations across compilation pipeline

File Diff View

Compare kernels from two different trace files side-by-side:

  • Cross-trace comparison: Validate optimizations, track kernel evolution, debug differences
  • Flexible modes: Single IR focus or all IRs simultaneously
  • Customizable diff: Ignore whitespace, word/line-level, context control
  • URL shareable: ?view=file_diff&json_url=trace1.gz&json_b_url=trace2.gz

πŸ’‘ Tip: See the Web Interface Guide for detailed File Diff documentation.

IR Stages Explained

| Stage | Description | When Generated |
|---|---|---|
| TTIR | Triton IR - language-level operations | After the Triton frontend parses the kernel |
| TTGIR | Triton GPU IR - GPU-specific operations | After lowering TTIR to the GPU dialect |
| LLIR | LLVM IR - low-level operations | After conversion to LLVM IR |
| PTX | NVIDIA PTX assembly | For NVIDIA GPUs |
| AMDGCN | AMD GPU assembly | For AMD GPUs |

πŸš€ Launch Analysis

TritonParse can analyze kernel launch parameters to identify variations and commonalities across different launches of the same kernel. This is useful for understanding how dynamic shapes or other factors affect kernel execution.

How it Works

  1. Enable Launch Tracing: You must enable launch tracing during the trace generation step. This is done by passing enable_trace_launch=True to tritonparse.structured_logging.init().
  2. Parsing: During the parsing step (tritonparse.utils.unified_parse), TritonParse will automatically group all launches for each kernel.
  3. Launch Diff Event: A new event of type launch_diff is generated for each kernel. This event contains:
    • total_launches: The total number of times the kernel was launched.
    • diffs: A dictionary showing which launch parameters (e.g., grid_x, grid_y) changed across launches and what their different values were.
    • sames: A dictionary showing which launch parameters remained constant across all launches.
    • launch_index_map: A mapping from the launch index to the original line number in the trace file.

Example launch_diff Event

{
  "event_type": "launch_diff",
  "hash": "...",
  "name": "triton_kernel_name",
  "total_launches": 10,
  "launch_index_map": { "0": 15, "1": 25, ... },
  "diffs": {
    "grid_x": [1024, 2048]
  },
  "sames": {
    "grid_y": 1,
    "grid_z": 1,
    "stream": 7
  }
}

This example shows that grid_x varied between 1024 and 2048 across 10 launches, while other parameters remained the same.
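To surface these events without the web UI, you can scan a parsed trace for launch_diff records and print which parameters varied. A hedged sketch that assumes only the fields shown in the example above (event_type, name, total_launches, diffs):

import gzip
import json

def print_launch_diffs(path):
    """Print the varying launch parameters for every launch_diff event in a parsed trace."""
    opener = gzip.open if path.endswith(".gz") else open
    with opener(path, "rt") as f:
        for line in f:
            event = json.loads(line)
            if event.get("event_type") != "launch_diff":
                continue
            print(f"{event['name']}: {event['total_launches']} launches")
            for param, values in event.get("diffs", {}).items():
                print(f"  {param} varied across launches: {values}")

print_launch_diffs("./parsed_output/trace.ndjson.gz")  # placeholder path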

πŸ”§ Reproducer - Generate Standalone Kernel Scripts

TritonParse can automatically generate standalone Python scripts that reproduce specific kernel executions. This is useful for debugging, sharing test cases, and isolating performance issues.

Quick Start

Command Line:

# Generate reproducer for first launch event
tritonparse reproduce ./parsed_output/trace.ndjson --line 1 --out-dir repro_output

# Using compressed files
tritonparse reproduce ./parsed_output/trace.ndjson.gz --line 5 --out-dir my_repro

# With custom template
tritonparse reproduce trace.ndjson --line 1 --template /path/to/my_template.py

Python API:

from tritonparse.reproducer.orchestrator import reproduce

result = reproduce(
    input_path="./parsed_output/trace.ndjson",
    line_index=1,                    # Which launch event (1-based)
    out_dir="repro_output",
    template="example"               # Built-in template
)

print(f"Script: {result['repo_script']}")
print(f"Context: {result['repo_context']}")

Generated Files

repro_output/<kernel_name>/
β”œβ”€β”€ repro_<timestamp>.py              # Standalone executable script
β”œβ”€β”€ repro_context_<timestamp>.json    # Kernel metadata and parameters
└── <hash>.bin                        # Tensor blobs (if enabled during tracing)

Parameters

| Parameter | Description | Default |
|---|---|---|
| input | Path to NDJSON trace file (.ndjson or .ndjson.gz) | Required |
| --line | Line number (1-based) of the launch event | 1 |
| --out-dir | Output directory | repro_output/<kernel>/ |
| --template | Template name or path | example |

Tensor Data Strategies

The reproducer reconstructs tensors using one of these strategies:

1. Blob Files (Highest Fidelity)

# Enable during tracing
tritonparse.structured_logging.init(
    "./logs/",
    enable_trace_launch=True,
    save_tensor_blobs=True  # Save actual tensor data
)

2. Statistical Reconstruction (Good Approximation)

  • Uses saved mean, std, min, max to generate similar data
  • Matches the shape, dtype, and device of the original (see the sketch after this list)

3. Random Data (Fallback)

  • Random generation matching only shape and dtype

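To make the difference between strategies 2 and 3 concrete, here is an illustrative sketch of what statistical reconstruction amounts to. This is not TritonParse's internal code, just the general idea: draw random data matching the recorded shape, dtype, and device, then shift it toward the saved statistics.

import torch

def reconstruct_tensor(shape, dtype, device, mean=None, std=None):
    """Illustrative only: rebuild a tensor that resembles the original."""
    data = torch.randn(shape, device=device)
    if mean is not None and std is not None:
        # Strategy 2: statistical reconstruction from the saved mean/std.
        data = data * std + mean
    # Strategy 3 (fallback) is the same call without the rescaling:
    # random data matching only shape and dtype.
    return data.to(dtype)

# Example: approximate a float16 tensor whose saved stats were mean=0.1, std=2.0.
t = reconstruct_tensor((1024, 1024), torch.float16, "cuda", mean=0.1, std=2.0)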
Common Use Cases

Bug Isolation:

tritonparse reproduce trace.ndjson --line 42 --out-dir bug_repro
cd bug_repro && python repro_*.py

Performance Benchmarking:

tritonparse reproduce trace.ndjson --line 1 --out-dir benchmark
# Modify script to add timing

Kernel Comparison:

tritonparse reproduce trace_v1.ndjson --line 1 --out-dir v1
tritonparse reproduce trace_v2.ndjson --line 1 --out-dir v2
# Compare outputs and performance

Custom Templates

Create your own template:

# my_template.py
"""Custom reproducer template"""
import torch
# {{KERNEL_IMPORT_PLACEHOLDER}}

if __name__ == "__main__":
    # {{KERNEL_INVOCATION_PLACEHOLDER}}
    print("Custom execution complete!")

Available Placeholders:

  • {{KERNEL_IMPORT_PLACEHOLDER}} - Kernel imports
  • {{KERNEL_INVOCATION_PLACEHOLDER}} - Launch code
  • {{KERNEL_SYSPATH_PLACEHOLDER}} - System path setup
  • {{JSON_FILE_NAME_PLACEHOLDER}} - Context JSON filename

Usage:

tritonparse reproduce trace.ndjson --line 1 --template /path/to/my_template.py

Advanced: Custom Types Support

For triton_kernels projects:

from triton_kernels.tensor import Tensor, Storage, StridedLayout
# Reproducer automatically handles these if triton_kernels is installed

If not installed:

RuntimeError: Optional dependency 'triton_kernels.tensor' is not installed

Solution: pip install triton_kernels


πŸ”„ Initialization Methods Comparison

TritonParse supports two initialization methods:

Method 1: Direct Initialization

tritonparse.structured_logging.init(
    trace_folder="./logs/",
    enable_trace_launch=True
)

Method 2: Environment Variables

import os
os.environ["TRITON_TRACE_FOLDER"] = "./logs/"
os.environ["TRITON_TRACE_LAUNCH"] = "1"

tritonparse.structured_logging.init_with_env()

Or from shell:

export TRITON_TRACE_FOLDER="./logs/"
export TRITON_TRACE_LAUNCH="1"
python my_script.py

Comparison Table

| Method | When to Use | Pros | Cons |
|---|---|---|---|
| init(path, enable_trace_launch=True) | Direct control in code | Explicit, type-safe | Hardcoded paths |
| init_with_env() | Environment-based config | Flexible, CI/CD friendly | Requires env setup |

πŸ” Advanced Features

Kernel Filtering

Trace only specific kernels:

export TRITONPARSE_KERNEL_ALLOWLIST="my_kernel*,important_*"
python your_script.py
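The same filter can be set from Python, mirroring the os.environ pattern used in the initialization section above. A minimal sketch; to be safe, set the variable before tritonparse.structured_logging.init() is called:

import os

# Only trace kernels whose names match these comma-separated patterns.
os.environ["TRITONPARSE_KERNEL_ALLOWLIST"] = "my_kernel*,important_*"

import tritonparse.structured_logging

tritonparse.structured_logging.init("./logs/", enable_trace_launch=True)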

Multi-GPU Analysis

Parse all ranks:

tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    all_ranks=True
)

Parse specific rank:

tritonparse.utils.unified_parse(
    source="./logs/",
    out="./parsed_output",
    rank=1
)

Command line:

tritonparse parse ./logs/ --out ./parsed_output --all-ranks
tritonparse parse ./logs/ --out ./parsed_output --rank 1

Reference

Supported Data Types

Floating point: float32, float16, bfloat16, float8_e4m3fn, float8_e5m2
Integers: int8, int16, int32, int64, uint8
Complex: complex64, complex128
Boolean: bool
Custom: triton_kernels.tensor.Tensor (requires triton_kernels)

File Formats

| Format | Source Mapping | Launch Diffs | Recommended |
|---|---|---|---|
| .ndjson.gz | βœ… Yes | βœ… Yes | βœ… Yes |
| .ndjson | ❌ No | ❌ No | ⚠️ Basic only |

Always use .gz files from parsed_output/ for full functionality.

IR Stages

| Stage | Description | When Generated |
|---|---|---|
| TTIR | Triton IR - language-level operations | After the Triton frontend parses the kernel |
| TTGIR | Triton GPU IR - GPU-specific operations | After lowering TTIR to the GPU dialect |
| LLIR | LLVM IR - low-level operations | After conversion to LLVM IR |
| PTX | NVIDIA PTX assembly | For NVIDIA GPUs |
| AMDGCN | AMD GPU assembly | For AMD GPUs |