Fixing GEMM Multi D on Tile Engine (#3583)

This commit is contained in:
Thrupti Raj Lakshmana Gowda
2026-01-16 12:17:21 -06:00
committed by GitHub
parent 644cdbe3c9
commit de8ee379ad

View File

@@ -676,36 +676,38 @@ struct SelectedKernel {{
if self.kernel_name_prefix == "gemm_multi_d":
instance_code += """
// Kernel type
using GemmKernelMultiD = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Kernel arguments
auto kargs = GemmKernelMultiD::MakeKernelArgs(args);
if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}
// Kernel type
using GemmKernelMultiD = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Kernel arguments
auto kargs = GemmKernelMultiD::MakeKernelArgs(args);
if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}
// Get grid and block sizes
const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = GemmKernelMultiD::BlockSize();
if(stream.log_level_ > 0) {
std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}"""
// Get grid and block sizes
const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = GemmKernelMultiD::BlockSize();
if(stream.log_level_ > 0) {
std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}"""
instance_code += f"""
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));
return ave_time;
}};"""
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));
return ave_time;
}}
}};
"""
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
instance_code += f"""
@@ -713,32 +715,32 @@ struct SelectedKernel {{
// Kernel type
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
// Kernel arguments
auto kargs = GemmKernel::MakeKernelArgs(args);
if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
// Kernel arguments
auto kargs = GemmKernel::MakeKernelArgs(args);
if (!GemmKernel::IsSupportedArgument(kargs)) {{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
}}
// Get grid and block sizes
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 blocks = GemmKernel::BlockSize();
if(stream.log_level_ > 0) {{
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}"""
// Get grid and block sizes
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 blocks = GemmKernel::BlockSize();
if(stream.log_level_ > 0) {{
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
<< std::endl;
}}"""
instance_code += f"""
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return ave_time;
// Launch kernel
constexpr int kBlockPerCu = {k_block_per_cu};
float ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return ave_time;
}}
}};
"""
@@ -747,8 +749,8 @@ struct SelectedKernel {{
def populate_epilogue(self, epilogue):
instance_code = """
// Epilogue
"""
// Epilogue
"""
if epilogue == "cshuffle":
if self.kernel_name_prefix == "gemm_universal":
@@ -769,145 +771,145 @@ struct SelectedKernel {{
def populate_cshuffle_gemm_universal(self):
instance_code = """
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
NumWaveGroups>; // kNumWaveGroups_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
NumWaveGroups>; // kNumWaveGroups_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
return instance_code
def populate_cshuffle_gemm_multi_d(self):
instance_code = """
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ElementWiseFn,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ElementWiseFn,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
return instance_code
def populate_cshuffle_gemm_preshuffle(self):
instance_code = """
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
NumWaveGroups, // kNumWaveGroups_
false, // FixedVectorSize_
1, // VectorSizeC_
PermuteN>; // isPermuteN_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
WarpPerBlock_M, // MWave_
WarpPerBlock_N, // NWave_
WarpTileM, // MPerXdl_
WarpTileN, // NPerXdl_
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
NumWaveGroups, // kNumWaveGroups_
false, // FixedVectorSize_
1, // VectorSizeC_
PermuteN>; // isPermuteN_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
return instance_code
def populate_default_gemm_universal(self):
instance_code = """
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
return instance_code
def populate_default_gemm_multi_d(self):
instance_code = """
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ElementWiseFn,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ElementWiseFn,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
return instance_code
def populate_default_gemm_preshuffle(self):
instance_code = """
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout
CLayout,
ck_tile::element_wise::PassThrough,
TileM, // kM_
TileN, // kN_
kPadM,
kPadN,
WarpTileM, // kMPerXdl_
WarpTileN, // kNPerXdl_
WarpTileK, // kKPerXdl_
TransposeC>; // isCTransposed_
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
return instance_code
def _generate_cmake_individual_targets(self, kernel_list):