From 6672cd0ff28bcf08b5b04122343acab4252bd58a Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Thu, 15 May 2025 18:07:42 +0000 Subject: [PATCH] Merge commit '3d8d6e75e485f5811df0ca37272f119392727726' into develop --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 26 ++--- .../warp/warp_gemm_attribute_mfma_impl.hpp | 8 +- .../gemm/configs/instance_combination.json | 4 +- tile_engine/ops/gemm/gemm_instance_builder.py | 96 +++++++++++++++---- 4 files changed, 96 insertions(+), 38 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index f050a8e382..be5d5690ff 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -204,14 +204,6 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl>>; -using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, - 2>>; - -using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, - 2>>; - using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; @@ -221,20 +213,28 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl< using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; -using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, +using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl, + 2>>; + +using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl, 2>>; using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl< WarpGemmAtrributeMfma>>; +using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< + WarpGemmAtrributeMfma>>; + +using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl, + 2>>; + using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl, 2>>; -using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl< - WarpGemmAtrributeMfma>>; - using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp index 69d22496f1..4bc4884beb 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp @@ -1092,7 +1092,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base } else { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); @@ -1116,7 +1116,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); @@ -1251,7 +1251,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base } else { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), c_vec, 0, 0, 0); @@ -1286,7 +1286,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base // c_vec = a_vec * b_vec CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const { -#if defined(__gfx94__) +#if defined(__gfx94__) or defined(__gfx95__) if constexpr(std::is_same_v && std::is_same_v) return bit_cast(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( bit_cast(a_vec), bit_cast(b_vec), CVecType{0.f}, 0, 0, 0)); diff --git a/tile_engine/ops/gemm/configs/instance_combination.json b/tile_engine/ops/gemm/configs/instance_combination.json index 53197ada6c..b497513efa 100644 --- a/tile_engine/ops/gemm/configs/instance_combination.json +++ b/tile_engine/ops/gemm/configs/instance_combination.json @@ -1,5 +1,7 @@ { - + "architecture": { + "values": ["gfx90a"] + }, "layout_a": { "values": ["r"] }, diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 3839523e3d..dd8b4d1157 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -23,7 +23,39 @@ DATA_TYPE_MAP = {'fp32' : 'float', } LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor', - 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} + 'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'} + + +warp_tile_combinations_map = { + "gfx90a": { + 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8': [[32, 32, 16], [32, 32, 32]], + 'bf8': [[32, 32, 16], [32, 32, 32]] + }, + "gfx942": { + 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], + 'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]] + }, + "gfx950": { + 'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]], + 'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]], + 'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]] + } + } + +def sizeOf(data_type): + if data_type == 'fp16' or data_type == 'bf16': + return 2 + elif data_type == 'int8' or data_type == 'fp8' or data_type == 'bf8': + return 1 + elif data_type == 'int4': ## TODO:: needs to confirm + return 0.5 + else: + return 4 DEFAULT_EPILOGUE = """ using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue< @@ -168,11 +200,15 @@ class GemmConfig: self.matrix_cfg : Dict[str, Any] = {} self.impl_cfg : Dict[str, Any] = {} for key, value in config_data.items(): - if key in ["datatype", "layout_a", "layout_b", "layout_c"]: + if key in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]: self.matrix_cfg[key] = value else: self.impl_cfg[key] = value + @property + def architecture(self) -> str: + return self.matrix_cfg["architecture"]["values"][0] + @property def datatype(self) -> str: return self.matrix_cfg["datatype"]["values"][0] @@ -201,7 +237,7 @@ class GemmCodeGenerator: def _validate_config(self): """Validate matrix and implementation configurations""" # Matrix config validation - for param in ["datatype", "layout_a", "layout_b", "layout_c"]: + for param in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]: if len(self.config.matrix_cfg[param]["values"]) != 1: raise ValueError(f"Matrix config {param} must have exactly one value") @@ -327,7 +363,7 @@ namespace {group_name} {{ return f""" template void try_run(ck_tile::TailNumber tn) {{ - if constexpr (Pipeline::PrefetchStages > static_cast(TN)) {{ + if constexpr (Pipeline::PrefetchStages > static_cast(TN) - 1) {{ if (tn == TN) {{ RunSplitk(ck_tile::bool_constant{{}}, ck_tile::integral_constant{{}}); @@ -477,6 +513,30 @@ struct GemmKernel {{ content += f"#include \"gemm_{group}.hpp\"\n" (self.output_dir / "gemm_instances.hpp").write_text(content) + def is_tile_valid(self, tile: tuple, group: str) -> bool: + """Check if the tile configuration is valid for the given group""" + # Extract tile parameters + tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile + + # Extract the pipeline and epilogue from the group name + _, pipeline, epilogue, scheduler, *_ = group.split("_") + + if tile_m % (warp_m * warp_tile_m) == 0 and \ + tile_n % (warp_n * warp_tile_n) == 0 and \ + tile_k % (warp_k * warp_tile_k) == 0: + total_tile_in_lds = (tile_m * tile_k + tile_n * tile_k ) * sizeOf(self.config.datatype) + # Validate and append valid tile parameters + is_compv4 = pipeline == "compv4" + max_tile_size = pow(2, 16) if is_compv4 else pow(2, 15) + + if total_tile_in_lds > max_tile_size: + raise ValueError(f'Total tile size should not exceed {max_tile_size / 1024}KB of LDS. ' + f'{tile_m} * {tile_n} * {tile_k} > {max_tile_size / 1024}KB') + arch = self.config.architecture + if [warp_tile_m, warp_tile_n, warp_tile_k] in warp_tile_combinations_map[arch][self.config.datatype]: + return True + return False + def _generate_dispatcher(self): """Generate dispatch mechanism""" content = """// SPDX-License-Identifier: MIT @@ -517,7 +577,7 @@ struct GemmDispatcher { self.config.impl_cfg["warp_tile_k"]["values"] )) - + for group in self.all_kernels: content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, @@ -526,26 +586,22 @@ struct GemmDispatcher { const ck_tile::stream_config& stream) {{ if(structured_sparsity){{ // SMFMA""" for tile in tile_params: - # Check if we have valid tile/warp combinations - # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m - if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ - ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): - continue - sparse = self.atype == 'fp16' and \ - ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or - (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) - content += f""" + if self.is_tile_valid(tile, group): + sparse = self.atype == 'fp16' and \ + ((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or + (tile[6] == 16 and tile[7] == 16 and tile[8] == 32)) + content += f""" run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + else: + raise ValueError(f"Invalid tile configuration for group {group}: {tile}") content += f""" }} else {{""" for tile in tile_params: - # Check if we have valid tile/warp combinations - # (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m - if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \ - ((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]): - continue - content += f""" + if self.is_tile_valid(tile, group): + content += f""" run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);""" + else: + raise ValueError(f"Invalid tile configuration for group {group}: {tile}") content += f""" }} }};\n"""