Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions python/ck4inductor/ck_tile_universal_gemm/gen_instances.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import functools
from .op import CKTileGemmOperation


def _ck_tile_gemm_instances(pipeline, tile_shapes, schedulers):
    """
    Build every CKTileGemmOperation for a single pipeline.

    The layout, datatype, warp, warp-tile, padding and epilogue axes are the
    same for all pipelines and are defined once here; only the pipeline
    name, the tile shapes and the scheduler list vary per call.

    Args:
        pipeline: pipeline name stored on every generated instance.
        tile_shapes: list of (tile_m, tile_n, tile_k) triples.
        schedulers: list of scheduler names to enumerate.

    Returns:
        A list holding one instance per element of the Cartesian product
        of all axes. The rightmost axis (epilogue) varies fastest, which is
        the same order a chain of nested ``for`` clauses would produce.
    """
    import itertools

    # Axes shared by every pipeline.
    layouts = [
        ("Row", "Row", "Row"),
        ("Row", "Col", "Row"),
    ]
    datatypes = [("FP16",) * 3, ("BF16",) * 3]
    warp_layouts = [(2, 2, 1)]
    warp_tiles = [(32, 32, 16)]
    padding_flags = ["true", "false"]
    epilogues = ["Default", "CShuffle"]

    return [
        CKTileGemmOperation(
            layout_a=layout_a,
            layout_b=layout_b,
            layout_c=layout_c,
            datatype_a=datatype_a,
            datatype_b=datatype_b,
            datatype_c=datatype_c,
            tile_m=tile_m,
            tile_n=tile_n,
            tile_k=tile_k,
            warp_m=warp_m,
            warp_n=warp_n,
            warp_k=warp_k,
            warp_tile_m=warp_tile_m,
            warp_tile_n=warp_tile_n,
            warp_tile_k=warp_tile_k,
            m_is_padded=m_is_padded,
            n_is_padded=n_is_padded,
            k_is_padded=k_is_padded,
            pipeline=pipeline,
            scheduler=scheduler,
            epilogue=epilogue,
        )
        for (
            (layout_a, layout_b, layout_c),
            (datatype_a, datatype_b, datatype_c),
            (tile_m, tile_n, tile_k),
            (warp_m, warp_n, warp_k),
            (warp_tile_m, warp_tile_n, warp_tile_k),
            m_is_padded,
            n_is_padded,
            k_is_padded,
            scheduler,
            epilogue,
        ) in itertools.product(
            layouts,
            datatypes,
            tile_shapes,
            warp_layouts,
            warp_tiles,
            padding_flags,  # m
            padding_flags,  # n
            padding_flags,  # k
            schedulers,
            epilogues,
        )
    ]


@functools.cache
def ops():
    """
    Generate the supported instance dataclasses

    Returns:
        A list of CKTileGemmOperation: all "CompV3" instances, then all
        "CompV4" instances, then all "Mem" instances.

    The result is memoized by ``functools.cache``, so every call returns
    the very same list object; callers should treat it as read-only.
    """
    compute_v3_instances = _ck_tile_gemm_instances(
        pipeline="CompV3",
        tile_shapes=[(256, 256, 32), (256, 256, 64)],
        schedulers=["Intrawave"],
    )

    compute_v4_instances = _ck_tile_gemm_instances(
        pipeline="CompV4",
        # half the tile size since it has double buffering
        tile_shapes=[(256, 256, 32)],
        schedulers=["Intrawave"],
    )

    mem_instances = _ck_tile_gemm_instances(
        pipeline="Mem",
        tile_shapes=[(256, 256, 32), (256, 256, 64)],
        schedulers=["Intrawave", "Interwave"],
    )

    return compute_v3_instances + compute_v4_instances + mem_instances


if __name__ == "__main__":
    # Debug aid: list the generated name of every supported instance,
    # one per line.
    for instance in ops():
        print(instance.name())
62 changes: 62 additions & 0 deletions python/ck4inductor/ck_tile_universal_gemm/op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from dataclasses import asdict, dataclass


@dataclass
class CKTileGemmOperation:
    """
    Parameter set describing one ck-tile universal GEMM instance.

    The class only stores the parameters and derives string
    representations from them; ``name()`` is built so that two instances
    with different parameters never share a name.
    """

    # Layout names of the A, B and C operands, e.g. "Row" or "Col".
    layout_a: str
    layout_b: str
    layout_c: str

    # Datatype names of the A, B and C operands, e.g. "FP16" or "BF16".
    datatype_a: str
    datatype_b: str
    datatype_c: str

    # Tile extents along M, N and K.
    tile_m: int
    tile_n: int
    tile_k: int

    # Warp counts along M, N and K.
    warp_m: int
    warp_n: int
    warp_k: int

    # Warp-tile extents along M, N and K.
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int

    # Padding flags per dimension, kept as the strings "true" / "false".
    m_is_padded: str
    n_is_padded: str
    k_is_padded: str

    # Pipeline, scheduler and epilogue names.
    pipeline: str
    scheduler: str
    epilogue: str

    def layout_repr(self):
        """Return the first letter of each operand layout, e.g. "RCR"."""
        return f"{self.layout_a[0]}{self.layout_b[0]}{self.layout_c[0]}"

    def dtype_repr(self):
        """Return the three datatype names concatenated, e.g. "FP16FP16FP16"."""
        return f"{self.datatype_a}{self.datatype_b}{self.datatype_c}"

    def tile_sizes(self):
        """
        Return the tile, warp and warp-tile shapes as one string.

        Each shape is rendered as "MxNxK" and the three shapes are joined
        with "_", e.g. "256x256x32_2x2x1_32x32x16". The "x" separator keeps
        the encoding unambiguous: without it (1, 23, 4) and (12, 3, 4)
        would both render as "1234".
        """
        return "_".join(
            [
                f"{self.tile_m}x{self.tile_n}x{self.tile_k}",
                f"{self.warp_m}x{self.warp_n}x{self.warp_k}",
                f"{self.warp_tile_m}x{self.warp_tile_n}x{self.warp_tile_k}",
            ]
        )

    def padding_repr(self):
        """
        Return "Pad" followed by one upper-case letter per padding flag.

        The letters are the first characters of the M, N and K flags in
        that order, so ("true", "false", "true") gives "PadTFT".
        """
        flags = (self.m_is_padded, self.n_is_padded, self.k_is_padded)
        # str() lets real booleans work as well as "true" / "false" strings.
        return "Pad" + "".join(str(flag)[0].upper() for flag in flags)

    def name(self):
        """
        Return an identifier-safe name for this instance.

        The name is "ck_tile_gemm_universal_" followed by the layout,
        datatype, tile-size, pipeline, scheduler, epilogue and padding
        components joined with "_". The padding component is part of the
        name because instances that differ only in their padding flags
        would otherwise be indistinguishable.
        """
        return "ck_tile_gemm_universal_" + "_".join(
            [
                self.layout_repr(),
                self.dtype_repr(),
                self.tile_sizes(),
                f"{self.pipeline}",
                f"{self.scheduler}",
                f"{self.epilogue}",
                self.padding_repr(),
            ]
        )

    def dict_items(self):
        """Return the (field name, value) pairs of this instance."""
        return asdict(self).items()
9 changes: 9 additions & 0 deletions python/test/test_gen_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from ck4inductor.batched_universal_gemm.gen_instances import (
gen_ops_library as gen_batched_gemm_ops_library,
)
from ck4inductor.ck_tile_universal_gemm.gen_instances import (
ops as gen_ck_tile_gemm_ops_library
)

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,3 +47,9 @@ def test_gen_batched_gemm_instances(self):

log.debug("%d gemm instances from library" % len(instances))
self.assertTrue(instances)

def test_gen_ck_tile_universal_gemm_instances(self):
    # The ck-tile universal GEMM generator must yield at least one instance.
    ck_tile_ops = gen_ck_tile_gemm_ops_library()
    instance_count = len(ck_tile_ops)

    log.debug("%d ck-tile gemm instances from library" % instance_count)
    self.assertTrue(ck_tile_ops)
Loading