11 changes: 9 additions & 2 deletions benchmark/examples/benchmark_all_gather_gemm_pull.py
@@ -63,7 +63,9 @@ def parse_args():
parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for the kernel")
parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for the kernel")
parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension")
parser.add_argument("--num_sms", type=int, default=304, help="Number of SMs for the kernel")
+parser.add_argument(
+    "--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)"
+)

parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.")

@@ -138,7 +140,12 @@ def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace):
A_local_iris = shmem.empty((M, K_local), dtype=datatype)
A_local_iris.copy_(A_local)

-num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+# Use provided num_sms or auto-detect
+if run_args["num_sms"] is None:
+    num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+    run_args["num_sms"] = num_sms
+else:
+    num_sms = run_args["num_sms"]

main_stream = torch.cuda.Stream()
kernel_timing = {
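Taken together, the two hunks above turn `--num_sms` from a hardcoded MI300 value (304) into an optional override. A minimal standalone sketch of the resulting pattern — the `resolve_num_sms` helper name is ours, not part of the PR:

```python
import argparse
import torch

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)"
    )
    return vars(parser.parse_args())

def resolve_num_sms(run_args: dict, device: int = 0) -> int:
    # Auto-detect only when the user did not pass --num_sms explicitly.
    if run_args["num_sms"] is None:
        # On ROCm builds of PyTorch, multi_processor_count reports the CU count.
        run_args["num_sms"] = torch.cuda.get_device_properties(device).multi_processor_count
    return run_args["num_sms"]
```

The push variant below applies the identical change.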
11 changes: 9 additions & 2 deletions benchmark/examples/benchmark_all_gather_gemm_push.py
@@ -62,7 +62,9 @@ def parse_args():
parser.add_argument("--BLK_N", type=int, default=64, help="Block size N for GEMM computation")
parser.add_argument("--BLK_K", type=int, default=64, help="Block size K for tiling")
parser.add_argument("--gsize_m", type=int, default=6, help="Group size in M dimension")
parser.add_argument("--num_sms", type=int, default=304, help="Number of SMs for the kernel")
+parser.add_argument(
+    "--num_sms", type=int, default=None, help="Number of SMs for the kernel (default: auto-detected)"
+)

parser.add_argument("--num_ranks", type=int, default=8, help="Number of GPUs to run the example on.")

@@ -142,7 +144,12 @@ def worker(rank: int, world_size: int, init_url: str, args: argparse.Namespace):
num_k_tiles = (K_local + run_args["BLK_K"] - 1) // run_args["BLK_K"]
signal_flags_iris = shmem.zeros((world_size, world_size, num_m_tiles, num_k_tiles), dtype=torch.int32)

-num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+# Use provided num_sms or auto-detect
+if run_args["num_sms"] is None:
+    num_sms = torch.cuda.get_device_properties(rank).multi_processor_count
+    run_args["num_sms"] = num_sms
+else:
+    num_sms = run_args["num_sms"]

main_stream = torch.cuda.Stream()
kernel_timing = {
14 changes: 12 additions & 2 deletions examples/07_gemm_all_scatter/benchmark.py
@@ -52,7 +52,12 @@ def parse_args():
parser.add_argument("--BLK_K", type=int, default=64, help="Block size K")
parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter")
parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")
parser.add_argument("--gemm_sms", type=int, default=304, help="Number of SMs for persistent GEMM algorithm")
+parser.add_argument(
+    "--gemm_sms",
+    type=int,
+    default=None,
+    help="Number of SMs for persistent GEMM algorithm (default: auto-detected)",
+)
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

return vars(parser.parse_args())
@@ -67,7 +72,12 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
shmem = iris.iris(args["heap_size"])
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
-cu_count = shmem.get_cu_count()

+# Set default SM values if not provided
+if args["gemm_sms"] is None:
+    # For all_scatter: use total CU count
+    cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+    args["gemm_sms"] = cu_count

# GEMM
datatype = torch.float32
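Note that this hunk also swaps `shmem.get_cu_count()` for a direct PyTorch query; assuming Iris reported the per-device CU count (as its name suggests), the two should agree. A quick check — device names and counts below are representative values, not guarantees:

```python
import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    # On ROCm, multi_processor_count is the CU count,
    # e.g. 304 on an MI300X or 104 per GCD on an MI250.
    print(f"{props.name}: {props.multi_processor_count} CUs")
```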
16 changes: 11 additions & 5 deletions examples/08_gemm_atomics_all_reduce/benchmark.py
@@ -11,6 +11,7 @@
import os
import argparse
import json
+import math

from examples.common.utils import (
JSONWriter,
@@ -68,10 +69,8 @@ def parse_args():
parser.add_argument("--kpack", type=int, default=2, help="K packing size")
parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")

-# For All Scatter, use: 288
-# For One Shot, use: 256
-parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM")
-parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs")
parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for GEMM (default: auto-detected)")
parser.add_argument("--total_sms", type=int, default=None, help="Total number of SMs (default: auto-detected)")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

return vars(parser.parse_args())
@@ -86,7 +85,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
shmem = iris.iris(args["heap_size"])
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
-cu_count = shmem.get_cu_count()

+# Set default SM values if not provided
+cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+if args["total_sms"] is None:
+    args["total_sms"] = cu_count
+if args["gemm_sms"] is None:
+    # For all_reduce: use next smaller power of 2, rest for communication
+    args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

# GEMM
datatype = torch.float32
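The new `gemm_sms` default is the largest power of two not exceeding the CU count (the comment's "next smaller power of 2"), leaving the remaining CUs for the all-reduce. On 304 CUs: int(log2(304)) = 8, so gemm_sms = 2**8 = 256 and 48 CUs are left over. A worked sketch — the helper name is ours:

```python
import math

def default_gemm_sms(cu_count: int) -> int:
    # Largest power of two <= cu_count; guard against an empty device.
    return 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

assert default_gemm_sms(304) == 256  # MI300X: 304 - 256 = 48 CUs left for comm
assert default_gemm_sms(104) == 64   # MI250 (per GCD): 40 CUs left for comm
assert default_gemm_sms(256) == 256  # exact powers of two are kept as-is
```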
14 changes: 11 additions & 3 deletions examples/09_gemm_one_shot_all_reduce/benchmark.py
@@ -11,6 +11,7 @@
import os
import argparse
import json
+import math

from examples.common.utils import (
JSONWriter,
@@ -68,8 +69,8 @@ def parse_args():
parser.add_argument("--kpack", type=int, default=2, help="K packing size")
parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")

parser.add_argument("--gemm_sms", type=int, default=288, help="Number of SMs for GEMM")
parser.add_argument("--total_sms", type=int, default=304, help="Total number of SMs")
parser.add_argument("--gemm_sms", type=int, default=None, help="Number of SMs for GEMM (default: auto-detected)")
parser.add_argument("--total_sms", type=int, default=None, help="Total number of SMs (default: auto-detected)")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")
return vars(parser.parse_args())

@@ -82,7 +83,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
shmem = iris.iris(args["heap_size"])
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
-cu_count = shmem.get_cu_count()

+# Set default SM values if not provided
+cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+if args["total_sms"] is None:
+    args["total_sms"] = cu_count
+if args["gemm_sms"] is None:
+    # For all_reduce: use next smaller power of 2, rest for communication
+    args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

# GEMM
datatype = torch.float32
22 changes: 19 additions & 3 deletions examples/10_gemm_all_scatter_wg_specialization/benchmark.py
@@ -11,6 +11,7 @@
import os
import argparse
import json
+import math

from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set
from examples.common.validation import validate_gemm
@@ -54,9 +55,17 @@ def parse_args():
parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter")
parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")
parser.add_argument(
"--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm"
"--gemm_sms",
type=int,
default=None,
help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)",
)
+parser.add_argument(
+    "--num_sms",
+    type=int,
+    default=None,
+    help="Number of total SMs for gemm + scatter kernel (default: auto-detected)",
+)
parser.add_argument("--num_sms", type=int, default=304, help="Number of total SMs for gemm + scatter kernel")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

return vars(parser.parse_args())
@@ -70,7 +79,14 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
shmem = iris.iris(args["heap_size"])
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
-cu_count = shmem.get_cu_count()

+# Set default SM values if not provided
+cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+if args["num_sms"] is None:
+    args["num_sms"] = cu_count
+if args["gemm_sms"] is None:
+    # For wg_specialized: use next smaller power of 2
+    args["gemm_sms"] = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1

# GEMM
datatype = torch.float32
22 changes: 19 additions & 3 deletions examples/11_gemm_all_scatter_producer_consumer/benchmark.py
@@ -11,6 +11,7 @@
import os
import argparse
import json
+import math

from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set
from examples.common.validation import validate_gemm
@@ -55,9 +56,14 @@ def parse_args():
parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter")
parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")
parser.add_argument(
"--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm"
"--gemm_sms",
type=int,
default=None,
help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)",
)
+parser.add_argument(
+    "--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)"
+)
parser.add_argument("--comm_sms", type=int, default=48, help="Number of SMs for All-Scatter kernel")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

return vars(parser.parse_args())
@@ -71,7 +77,17 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
shmem = iris.iris(args["heap_size"])
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
-cu_count = shmem.get_cu_count()

+# Set default SM values if not provided
+cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1
+
+if args["gemm_sms"] is None:
+    # For wg_specialized: use next smaller power of 2
+    args["gemm_sms"] = next_pow2
+if args["comm_sms"] is None:
+    # comm_sms is the leftover: total - next_power_of_2
+    args["comm_sms"] = cu_count - next_pow2

# GEMM
datatype = torch.float32
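Here the split is made explicit: GEMM takes the power-of-two share and the all-scatter kernel takes the remainder. On a 304-CU device this reproduces the previously hardcoded defaults (gemm_sms=256, comm_sms=48) while still adapting to other GPUs. A sketch under that 304-CU assumption:

```python
import math

cu_count = 304  # assumed: MI300X
next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1
gemm_sms = next_pow2              # 256 CUs for the GEMM
comm_sms = cu_count - next_pow2   # 48 CUs for the all-scatter kernel
assert (gemm_sms, comm_sms) == (256, 48)  # matches the old hardcoded defaults
```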
22 changes: 19 additions & 3 deletions examples/12_gemm_all_scatter_bulk_synchronous/benchmark.py
@@ -11,6 +11,7 @@
import os
import argparse
import json
+import math

from examples.common.utils import JSONWriter, Timestamps, is_triton_interpret_set
from examples.common.validation import validate_gemm
@@ -55,9 +56,14 @@ def parse_args():
parser.add_argument("--gsize_m", type=int, default=6, help="L2-cache locality swizzle parameter")
parser.add_argument("--heap_size", type=int, default=1 << 33, help="Iris heap size")
parser.add_argument(
"--gemm_sms", type=int, default=256, help="Number of SMs for workgroup-specialized GEMM algorithm"
"--gemm_sms",
type=int,
default=None,
help="Number of SMs for workgroup-specialized GEMM algorithm (default: auto-detected)",
)
+parser.add_argument(
+    "--comm_sms", type=int, default=None, help="Number of SMs for All-Scatter kernel (default: auto-detected)"
+)
parser.add_argument("--comm_sms", type=int, default=256, help="Number of SMs for All-Scatter kernel")
parser.add_argument("-r", "--num_ranks", type=int, default=2, help="Number of ranks/processes")

return vars(parser.parse_args())
@@ -71,7 +77,17 @@ def _worker(local_rank: int, world_size: int, init_url: str, args: dict):
shmem = iris.iris(args["heap_size"])
rank = shmem.get_rank()
world_size = shmem.get_num_ranks()
-cu_count = shmem.get_cu_count()

+# Set default SM values if not provided
+cu_count = torch.cuda.get_device_properties(rank).multi_processor_count
+next_pow2 = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1
+
+if args["gemm_sms"] is None:
+    # For wg_specialized: use next smaller power of 2
+    args["gemm_sms"] = next_pow2
+if args["comm_sms"] is None:
+    # For bulk synchronous, use same as gemm_sms
+    args["comm_sms"] = next_pow2

# GEMM
datatype = torch.float32
31 changes: 22 additions & 9 deletions examples/benchmark/bench_all_shapes.py
@@ -6,6 +6,7 @@
from datetime import datetime
import argparse
import json
+import torch


def launch_sbatch(
@@ -110,16 +111,28 @@ def main(hashes, config, sbatch_script_content, input_json, tiling_json, dry_run
if mkn not in mkn_gemm_tiles:
mkn_gemm_tiles[mkn] = {key: entry[key] for key in optional_keys if key in entry}

if config["partition"] is not None:
if "mi300" in config["partition"]:
print("Running on MI300")
+# Determine gemm_sms based on available GPU or partition name
+try:
+    if torch.cuda.is_available():
+        gemm_sms = torch.cuda.get_device_properties(0).multi_processor_count
+        print(f"Auto-detected CU count: {gemm_sms}")
+    else:
+        gemm_sms = None
+except Exception:
+    # Fall back to partition-based detection
+    gemm_sms = None
+
+if gemm_sms is None:
+    if config["partition"] is not None:
+        if "mi300" in config["partition"]:
+            print("Running on MI300 (partition-based)")
gemm_sms = 304
elif "mi250" in config["partition"]:
print("Running on MI250 (partition-based)")
gemm_sms = 104
else:
print("Assuming MI300 (default)")
gemm_sms = 304
elif "mi250" in config["partition"]:
print("Running on MI250")
gemm_sms = 104
else:
print("Assuming MI300")
gemm_sms = 304

enable_algorithms = False
enable_mkn = True
11 changes: 10 additions & 1 deletion scripts/link_bandwidth.py
@@ -2,6 +2,15 @@
# Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.

import json
+import torch
+
+try:
+    if torch.cuda.is_available():
+        cu_count = torch.cuda.get_device_properties(0).multi_processor_count
+    else:
+        cu_count = 304  # Default for MI300
+except Exception:
+    cu_count = 304  # Default for MI300

# Sample input (replace with file read if needed)
config = {
@@ -26,7 +35,7 @@
"kpack": 2,
"heap_size": 8589934592,
"gemm_sms": 48,
"total_sms": 304,
"total_sms": cu_count,
"communication_block_size": 256,
"communication_sms_multiplier": 1,
"M": 8192,
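The same detect-or-fall-back logic now appears in several files. Purely as an illustration (not part of this PR), a shared helper could consolidate it:

```python
import math
import torch

def detect_cu_count(fallback: int = 304) -> int:
    """Return the device CU count, or `fallback` (an MI300-style default)
    when no GPU is visible (e.g. a CPU-only scheduler node) or the query fails."""
    try:
        if torch.cuda.is_available():
            return torch.cuda.get_device_properties(0).multi_processor_count
    except Exception:
        pass
    return fallback

def split_sms(cu_count: int) -> tuple[int, int]:
    """GEMM gets the largest power of two <= cu_count; the rest goes to comm."""
    gemm_sms = 2 ** int(math.log2(cu_count)) if cu_count > 0 else 1
    return gemm_sms, cu_count - gemm_sms
```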