diff --git a/python/ck4inductor/ck_tile_universal_gemm/gen_instances.py b/python/ck4inductor/ck_tile_universal_gemm/gen_instances.py new file mode 100644 index 00000000000..ce9e7379345 --- /dev/null +++ b/python/ck4inductor/ck_tile_universal_gemm/gen_instances.py @@ -0,0 +1,136 @@ +import functools +from .op import CKTileGemmOperation + + +@functools.cache +def ops(): + """ + Generate the supported instance dataclasses + """ + import itertools + + compute_v3_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV3", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + compute_v4_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="CompV4", + scheduler="Intrawave", + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [ + (256, 256, 32) + ] # half the tile size since it has double buffering + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for epilogue in ["Default", "CShuffle"] + ] + + mem_instances = [ + CKTileGemmOperation( + layout_a=layout_a, + layout_b=layout_b, + layout_c=layout_c, + datatype_a=datatype_a, + datatype_b=datatype_b, + datatype_c=datatype_c, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=warp_k, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + m_is_padded=m_is_padded, + n_is_padded=n_is_padded, + k_is_padded=k_is_padded, + pipeline="Mem", + scheduler=scheduler, + epilogue=epilogue, + ) + for (layout_a, layout_b, layout_c) in [ + ("Row", "Row", "Row"), + ("Row", "Col", "Row"), + ] + for (datatype_a, datatype_b, datatype_c) in [("FP16",) * 3, ("BF16",) * 3] + for (tile_m, tile_n, tile_k) in [(256, 256, 32), (256, 256, 64)] + for (warp_m, warp_n, warp_k) in [(2, 2, 1)] + for (warp_tile_m, warp_tile_n, warp_tile_k) in [(32, 32, 16)] + for m_is_padded in ["true", "false"] + for n_is_padded in ["true", "false"] + for k_is_padded in ["true", "false"] + for scheduler in ["Intrawave", "Interwave"] + for epilogue in ["Default", "CShuffle"] + ] + + return list( + itertools.chain(compute_v3_instances, compute_v4_instances, mem_instances) + ) + + +if __name__ == "__main__": + for op in ops(): + print(op.name()) diff --git a/python/ck4inductor/ck_tile_universal_gemm/op.py b/python/ck4inductor/ck_tile_universal_gemm/op.py new file mode 100644 index 00000000000..0efbcd8b7c3 --- /dev/null +++ b/python/ck4inductor/ck_tile_universal_gemm/op.py @@ -0,0 +1,62 @@ +from dataclasses import asdict, dataclass + + +@dataclass +class CKTileGemmOperation: + layout_a: str + layout_b: str + layout_c: str + + datatype_a: str + datatype_b: str + datatype_c: str + + tile_m: int + tile_n: int + tile_k: int + + warp_m: int + warp_n: int + warp_k: int + + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + m_is_padded: str + n_is_padded: str + k_is_padded: str + + pipeline: str + scheduler: str + epilogue: str + + def layout_repr(self): + return f"{self.layout_a[0]}{self.layout_b[0]}{self.layout_c[0]}" + + def dtype_repr(self): + return f"{self.datatype_a}{self.datatype_b}{self.datatype_c}" + + def tile_sizes(self): + return "_".join( + [ + f"{self.tile_m}{self.tile_n}{self.tile_k}", + f"{self.warp_m}{self.warp_n}{self.warp_k}", + f"{self.warp_tile_m}{self.warp_tile_n}{self.warp_tile_k}", + ] + ) + + def name(self): + return "ck_tile_gemm_universal_" + "_".join( + [ + f"{self.layout_repr()}", + f"{self.dtype_repr()}", + f"{self.tile_sizes()}", + f"{self.pipeline}", + f"{self.scheduler}", + f"{self.epilogue}", + ] + ) + + def dict_items(self): + return asdict(self).items() diff --git a/python/test/test_gen_instances.py b/python/test/test_gen_instances.py index 4a85c702f9c..561a3073fe4 100644 --- a/python/test/test_gen_instances.py +++ b/python/test/test_gen_instances.py @@ -16,6 +16,9 @@ from ck4inductor.batched_universal_gemm.gen_instances import ( gen_ops_library as gen_batched_gemm_ops_library, ) +from ck4inductor.ck_tile_universal_gemm.gen_instances import ( + ops as gen_ck_tile_gemm_ops_library +) log = logging.getLogger(__name__) @@ -44,3 +47,9 @@ def test_gen_batched_gemm_instances(self): log.debug("%d gemm instances from library" % len(instances)) self.assertTrue(instances) + + def test_gen_ck_tile_universal_gemm_instances(self): + instances = gen_ck_tile_gemm_ops_library() + + log.debug("%d ck-tile gemm instances from library" % len(instances)) + self.assertTrue(instances)