mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Fixing GEMM Multi D on Tile Engine (#3583)
This commit is contained in:
committed by
GitHub
parent
644cdbe3c9
commit
de8ee379ad
@@ -676,36 +676,38 @@ struct SelectedKernel {{
|
||||
if self.kernel_name_prefix == "gemm_multi_d":
|
||||
instance_code += """
|
||||
|
||||
// Kernel type
|
||||
using GemmKernelMultiD = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
// Kernel arguments
|
||||
auto kargs = GemmKernelMultiD::MakeKernelArgs(args);
|
||||
|
||||
if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}
|
||||
// Kernel type
|
||||
using GemmKernelMultiD = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
// Kernel arguments
|
||||
auto kargs = GemmKernelMultiD::MakeKernelArgs(args);
|
||||
|
||||
if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}
|
||||
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = GemmKernelMultiD::BlockSize();
|
||||
|
||||
if(stream.log_level_ > 0) {
|
||||
std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}"""
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = GemmKernelMultiD::BlockSize();
|
||||
|
||||
if(stream.log_level_ > 0) {
|
||||
std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}"""
|
||||
|
||||
instance_code += f"""
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}};"""
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
|
||||
instance_code += f"""
|
||||
@@ -713,32 +715,32 @@ struct SelectedKernel {{
|
||||
// Kernel type
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
// Kernel arguments
|
||||
auto kargs = GemmKernel::MakeKernelArgs(args);
|
||||
|
||||
if (!GemmKernel::IsSupportedArgument(kargs)) {{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}}
|
||||
// Kernel arguments
|
||||
auto kargs = GemmKernel::MakeKernelArgs(args);
|
||||
|
||||
if (!GemmKernel::IsSupportedArgument(kargs)) {{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}}
|
||||
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
|
||||
if(stream.log_level_ > 0) {{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
|
||||
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
||||
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
||||
<< std::endl;
|
||||
}}"""
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
|
||||
if(stream.log_level_ > 0) {{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
|
||||
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
||||
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
||||
<< std::endl;
|
||||
}}"""
|
||||
|
||||
instance_code += f"""
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
@@ -747,8 +749,8 @@ struct SelectedKernel {{
|
||||
def populate_epilogue(self, epilogue):
|
||||
instance_code = """
|
||||
|
||||
// Epilogue
|
||||
"""
|
||||
// Epilogue
|
||||
"""
|
||||
|
||||
if epilogue == "cshuffle":
|
||||
if self.kernel_name_prefix == "gemm_universal":
|
||||
@@ -769,145 +771,145 @@ struct SelectedKernel {{
|
||||
|
||||
def populate_cshuffle_gemm_universal(self):
|
||||
instance_code = """
|
||||
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
WarpPerBlock_M, // MWave_
|
||||
WarpPerBlock_N, // NWave_
|
||||
WarpTileM, // MPerXdl_
|
||||
WarpTileN, // NPerXdl_
|
||||
WarpTileK, // KPerXdl_
|
||||
TransposeC, // isCTransposed_
|
||||
NumWaveGroups>; // kNumWaveGroups_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
|
||||
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
WarpPerBlock_M, // MWave_
|
||||
WarpPerBlock_N, // NWave_
|
||||
WarpTileM, // MPerXdl_
|
||||
WarpTileN, // NPerXdl_
|
||||
WarpTileK, // KPerXdl_
|
||||
TransposeC, // isCTransposed_
|
||||
NumWaveGroups>; // kNumWaveGroups_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
|
||||
return instance_code
|
||||
|
||||
def populate_cshuffle_gemm_multi_d(self):
|
||||
instance_code = """
|
||||
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ElementWiseFn,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
WarpPerBlock_M, // MWave_
|
||||
WarpPerBlock_N, // NWave_
|
||||
WarpTileM, // MPerXdl_
|
||||
WarpTileN, // NPerXdl_
|
||||
WarpTileK, // KPerXdl_
|
||||
TransposeC>; // isCTransposed_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
|
||||
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ElementWiseFn,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
WarpPerBlock_M, // MWave_
|
||||
WarpPerBlock_N, // NWave_
|
||||
WarpTileM, // MPerXdl_
|
||||
WarpTileN, // NPerXdl_
|
||||
WarpTileK, // KPerXdl_
|
||||
TransposeC>; // isCTransposed_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
|
||||
return instance_code
|
||||
|
||||
def populate_cshuffle_gemm_preshuffle(self):
|
||||
instance_code = """
|
||||
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
WarpPerBlock_M, // MWave_
|
||||
WarpPerBlock_N, // NWave_
|
||||
WarpTileM, // MPerXdl_
|
||||
WarpTileN, // NPerXdl_
|
||||
WarpTileK, // KPerXdl_
|
||||
TransposeC, // isCTransposed_
|
||||
NumWaveGroups, // kNumWaveGroups_
|
||||
false, // FixedVectorSize_
|
||||
1, // VectorSizeC_
|
||||
PermuteN>; // isPermuteN_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
|
||||
using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
WarpPerBlock_M, // MWave_
|
||||
WarpPerBlock_N, // NWave_
|
||||
WarpTileM, // MPerXdl_
|
||||
WarpTileN, // NPerXdl_
|
||||
WarpTileK, // KPerXdl_
|
||||
TransposeC, // isCTransposed_
|
||||
NumWaveGroups, // kNumWaveGroups_
|
||||
false, // FixedVectorSize_
|
||||
1, // VectorSizeC_
|
||||
PermuteN>; // isPermuteN_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
|
||||
return instance_code
|
||||
|
||||
def populate_default_gemm_universal(self):
|
||||
instance_code = """
|
||||
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM, // kMPerXdl_
|
||||
WarpTileN, // kNPerXdl_
|
||||
WarpTileK, // kKPerXdl_
|
||||
TransposeC>; // isCTransposed_
|
||||
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
|
||||
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM, // kMPerXdl_
|
||||
WarpTileN, // kNPerXdl_
|
||||
WarpTileK, // kKPerXdl_
|
||||
TransposeC>; // isCTransposed_
|
||||
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
|
||||
return instance_code
|
||||
|
||||
def populate_default_gemm_multi_d(self):
|
||||
instance_code = """
|
||||
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ElementWiseFn,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM, // kMPerXdl_
|
||||
WarpTileN, // kNPerXdl_
|
||||
WarpTileK, // kKPerXdl_
|
||||
TransposeC>; // isCTransposed_
|
||||
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
|
||||
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ElementWiseFn,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM, // kMPerXdl_
|
||||
WarpTileN, // kNPerXdl_
|
||||
WarpTileK, // kKPerXdl_
|
||||
TransposeC>; // isCTransposed_
|
||||
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
|
||||
return instance_code
|
||||
|
||||
def populate_default_gemm_preshuffle(self):
|
||||
instance_code = """
|
||||
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM, // kMPerXdl_
|
||||
WarpTileN, // kNPerXdl_
|
||||
WarpTileK, // kKPerXdl_
|
||||
TransposeC>; // isCTransposed_
|
||||
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
|
||||
using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM, // kMPerXdl_
|
||||
WarpTileN, // kNPerXdl_
|
||||
WarpTileK, // kKPerXdl_
|
||||
TransposeC>; // isCTransposed_
|
||||
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
|
||||
return instance_code
|
||||
|
||||
def _generate_cmake_individual_targets(self, kernel_list):
|
||||
|
||||
Reference in New Issue
Block a user