Merge commit '3d8d6e75e485f5811df0ca37272f119392727726' into develop

This commit is contained in:
assistant-librarian[bot]
2025-05-15 18:07:42 +00:00
parent 0e4215ae3f
commit 6672cd0ff2
4 changed files with 96 additions and 38 deletions

View File

@@ -204,14 +204,6 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterat
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
@@ -221,20 +213,28 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
using WarpGemmMfma_f32_32x32x32_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_32x32x32_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x32_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x64_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_fp8_fp8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfma_f32_16x16x32_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x32_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl<WarpGemmAtrributeMfma<
WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>>>;

View File

@@ -1092,7 +1092,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
}
else
{
#if defined(__gfx94__)
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
@@ -1116,7 +1116,7 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x32_f8_base
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx94__)
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
@@ -1251,7 +1251,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
}
else
{
#if defined(__gfx94__)
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
@@ -1286,7 +1286,7 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx94__)
#if defined(__gfx94__) or defined(__gfx95__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));

View File

@@ -1,5 +1,7 @@
{
"architecture": {
"values": ["gfx90a"]
},
"layout_a": {
"values": ["r"]
},

View File

@@ -23,7 +23,39 @@ DATA_TYPE_MAP = {'fp32' : 'float',
}
LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor',
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
warp_tile_combinations_map = {
"gfx90a": {
'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'fp8': [[32, 32, 16], [32, 32, 32]],
'bf8': [[32, 32, 16], [32, 32, 32]]
},
"gfx942": {
'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
},
"gfx950": {
'fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
}
}
def sizeOf(data_type):
if data_type == 'fp16' or data_type == 'bf16':
return 2
elif data_type == 'int8' or data_type == 'fp8' or data_type == 'bf8':
return 1
elif data_type == 'int4': ## TODO:: needs to confirm
return 0.5
else:
return 4
DEFAULT_EPILOGUE = """
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
@@ -168,11 +200,15 @@ class GemmConfig:
self.matrix_cfg : Dict[str, Any] = {}
self.impl_cfg : Dict[str, Any] = {}
for key, value in config_data.items():
if key in ["datatype", "layout_a", "layout_b", "layout_c"]:
if key in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]:
self.matrix_cfg[key] = value
else:
self.impl_cfg[key] = value
@property
def architecture(self) -> str:
return self.matrix_cfg["architecture"]["values"][0]
@property
def datatype(self) -> str:
return self.matrix_cfg["datatype"]["values"][0]
@@ -201,7 +237,7 @@ class GemmCodeGenerator:
def _validate_config(self):
"""Validate matrix and implementation configurations"""
# Matrix config validation
for param in ["datatype", "layout_a", "layout_b", "layout_c"]:
for param in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]:
if len(self.config.matrix_cfg[param]["values"]) != 1:
raise ValueError(f"Matrix config {param} must have exactly one value")
@@ -327,7 +363,7 @@ namespace {group_name} {{
return f"""
template <typename Pipeline, ck_tile::TailNumber TN>
void try_run(ck_tile::TailNumber tn) {{
if constexpr (Pipeline::PrefetchStages > static_cast<int>(TN)) {{
if constexpr (Pipeline::PrefetchStages > static_cast<int>(TN) - 1) {{
if (tn == TN) {{
RunSplitk(ck_tile::bool_constant<true>{{}},
ck_tile::integral_constant<ck_tile::TailNumber, TN>{{}});
@@ -477,6 +513,30 @@ struct GemmKernel {{
content += f"#include \"gemm_{group}.hpp\"\n"
(self.output_dir / "gemm_instances.hpp").write_text(content)
def is_tile_valid(self, tile: tuple, group: str) -> bool:
"""Check if the tile configuration is valid for the given group"""
# Extract tile parameters
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile
# Extract the pipeline and epilogue from the group name
_, pipeline, epilogue, scheduler, *_ = group.split("_")
if tile_m % (warp_m * warp_tile_m) == 0 and \
tile_n % (warp_n * warp_tile_n) == 0 and \
tile_k % (warp_k * warp_tile_k) == 0:
total_tile_in_lds = (tile_m * tile_k + tile_n * tile_k ) * sizeOf(self.config.datatype)
# Validate and append valid tile parameters
is_compv4 = pipeline == "compv4"
max_tile_size = pow(2, 16) if is_compv4 else pow(2, 15)
if total_tile_in_lds > max_tile_size:
raise ValueError(f'Total tile size should not exceed {max_tile_size / 1024}KB of LDS. '
f'{tile_m} * {tile_n} * {tile_k} > {max_tile_size / 1024}KB')
arch = self.config.architecture
if [warp_tile_m, warp_tile_n, warp_tile_k] in warp_tile_combinations_map[arch][self.config.datatype]:
return True
return False
def _generate_dispatcher(self):
"""Generate dispatch mechanism"""
content = """// SPDX-License-Identifier: MIT
@@ -517,7 +577,7 @@ struct GemmDispatcher {
self.config.impl_cfg["warp_tile_k"]["values"]
))
for group in self.all_kernels:
content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
@@ -526,26 +586,22 @@ struct GemmDispatcher {
const ck_tile::stream_config& stream) {{
if(structured_sparsity){{ // SMFMA"""
for tile in tile_params:
# Check if we have valid tile/warp combinations
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
sparse = self.atype == 'fp16' and \
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
content += f"""
if self.is_tile_valid(tile, group):
sparse = self.atype == 'fp16' and \
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
content += f"""
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
else:
raise ValueError(f"Invalid tile configuration for group {group}: {tile}")
content += f"""
}} else {{"""
for tile in tile_params:
# Check if we have valid tile/warp combinations
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
content += f"""
if self.is_tile_valid(tile, group):
content += f"""
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
else:
raise ValueError(f"Invalid tile configuration for group {group}: {tile}")
content += f"""
}}
}};\n"""