mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_Tile] Fix gemm kernel for 4,64,16 and 64,4,16 warp tile sizes (#2262)
* debugging issue
* debugging issue
* debugging
* debugging
* reverting debugging code
* clang formatted
* updating default_config.json
* fix ci failure
* clang formatted
[ROCm/composable_kernel commit: 59a85cb4bc]
This commit is contained in:
@@ -121,8 +121,11 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
@@ -92,7 +92,20 @@ struct CShuffleEpilogue
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
{
|
||||
constexpr index_t MaxVectorStoreSize = 16;
|
||||
return MaxVectorStoreSize / sizeof(ODataType);
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(kNPerIteration),
|
||||
static_cast<int>(MaxVectorStoreSize / sizeof(ODataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(kMPerIteration),
|
||||
static_cast<int>(MaxVectorStoreSize / sizeof(ODataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported CLayout!");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -149,7 +149,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
else
|
||||
{
|
||||
// In this case each thread has just a single item in Ndim
|
||||
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
|
||||
return (WG::WarpGemmAttribute::Impl::kCNLane *
|
||||
WG::WarpGemmAttribute::Impl::kBNBlock) /
|
||||
WG::kN;
|
||||
}
|
||||
}
|
||||
// M is contiguous dimension
|
||||
@@ -158,7 +160,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
if constexpr(isCTransposed)
|
||||
{
|
||||
// In this case each thread has just a single item in Mdim
|
||||
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
|
||||
return (WG::WarpGemmAttribute::Impl::kCNLane *
|
||||
WG::WarpGemmAttribute::Impl::kAMBlock) /
|
||||
WG::kN;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -167,20 +167,20 @@ def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
|
||||
# To Do: add some more supported combinations
|
||||
warp_tile_supported_combinations = {
|
||||
"gfx90a": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
|
||||
},
|
||||
"gfx942": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
|
||||
},
|
||||
"gfx950": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@
|
||||
"max": 512,
|
||||
"min": 64,
|
||||
"step": 64,
|
||||
"exclude": []
|
||||
"exclude": [192]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
@@ -71,14 +71,20 @@
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
4,
|
||||
8,
|
||||
16,
|
||||
32
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
|
||||
@@ -420,12 +420,12 @@ struct GemmKernel {{
|
||||
|
||||
# LDS capacity verification
|
||||
matrix_a_size = (tile_m * tile_k) * \
|
||||
pow(2, element_size(self.config.problem.datatype_map['matrix_a']))
|
||||
element_size(self.config.problem.datatype_map['matrix_a'])
|
||||
matrix_b_size = (tile_n * tile_k) * \
|
||||
pow(2, element_size(self.config.problem.datatype_map['matrix_b']))
|
||||
element_size(self.config.problem.datatype_map['matrix_b'])
|
||||
total_tile_in_lds = matrix_a_size + matrix_b_size
|
||||
|
||||
max_tile_size = 2**16 if pipeline == "compv4" else 2**15
|
||||
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
|
||||
if total_tile_in_lds > max_tile_size:
|
||||
logging.debug(
|
||||
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds/1024:.1f}KB) > "
|
||||
@@ -493,6 +493,9 @@ struct GemmKernel {{
|
||||
for trait in self.valid_trait_names:
|
||||
tile_valid_params = list(
|
||||
filter(lambda t: self.is_tile_valid(t, trait), tile_params))
|
||||
|
||||
# if len(tile_valid_params) == 0:
|
||||
# raise RuntimeError(f"No valid kernel instance selected for trait: {trait}")
|
||||
if trait not in self.valid_trait_tile_combinations:
|
||||
self.valid_trait_tile_combinations[trait] = []
|
||||
self.valid_trait_tile_combinations[trait].append(tile_valid_params)
|
||||
|
||||
Reference in New Issue
Block a user