mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Add other layouts for FP8 block scaled gemm (#2665)
* Start adding other layouts for gemm_ab_scale
* Add some instances
* Create tensor descriptors for A/B scales depending on A/B layout
* Fix formatting
* Revert some comments
* Revert commented instances in CMakeLists.txt
* Add some more instances for col-row gemm
* Enable more row,row instances
* Use occupancy=1 for col,row layout to avoid spills
[ROCm/composable_kernel commit: 26d3300930]
This commit is contained in:
@@ -231,11 +231,22 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
(BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
|
||||
MPerBlock * NPerBlock / BlockSize > 64)
|
||||
? 1
|
||||
: 2;
|
||||
constexpr index_t minimum_occupancy = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout> &&
|
||||
is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
// FIXME: many instances have many spills with occupancy > 1, a better solution
|
||||
// needed to get best performance
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
|
||||
MPerBlock * NPerBlock / BlockSize > 64)
|
||||
? 1
|
||||
: 2;
|
||||
}
|
||||
}();
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
|
||||
@@ -402,6 +402,34 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
}
|
||||
}
|
||||
|
||||
// Builds the descriptor of the A-scale grid, viewed as a 2-D tensor of
// lengths (BM, BK), where
//   BM = ceil(M / ScaleBlockM)
//   BK = ceil(K / ScaleBlockK)
//
// The strides are selected from ALayout:
//   RowMajor    -> strides (BK, 1): consecutive BK indices are adjacent.
//   ColumnMajor -> strides (1, BM): consecutive BM indices are adjacent.
//
// Any other ALayout is rejected at compile time. Without that check the
// function would have no return statement for such a layout, deduce `void`,
// and only fail later at the call site with an unrelated diagnostic.
__host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K)
{
    static_assert(is_same<tensor_layout::gemm::RowMajor, ALayout>::value ||
                      is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value,
                  "MakeAScaleGridDesciptor_M_K: ALayout must be RowMajor or ColumnMajor");

    // Number of scale blocks along M and K (rounded up so partial blocks are covered).
    const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
    const auto BK = math::integer_divide_ceil(K, ScaleBlockK);

    if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
    {
        return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(BK, I1));
    }
    else
    {
        // ColumnMajor (guaranteed by the static_assert above).
        return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, BM));
    }
}
|
||||
|
||||
// Builds the descriptor of the B-scale grid, viewed as a 2-D tensor of
// lengths (BN, BK), where
//   BN = ceil(N / ScaleBlockN)
//   BK = ceil(K / ScaleBlockK)
//
// The strides are selected from BLayout:
//   ColumnMajor -> strides (BK, 1): consecutive BK indices are adjacent.
//   RowMajor    -> strides (1, BN): consecutive BN indices are adjacent.
//
// Any other BLayout is rejected at compile time. Without that check the
// function would have no return statement for such a layout, deduce `void`,
// and only fail later at the call site with an unrelated diagnostic.
__host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K)
{
    static_assert(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value ||
                      is_same<tensor_layout::gemm::RowMajor, BLayout>::value,
                  "MakeBScaleGridDesciptor_N_K: BLayout must be RowMajor or ColumnMajor");

    // Number of scale blocks along N and K (rounded up so partial blocks are covered).
    const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
    const auto BK = math::integer_divide_ceil(K, ScaleBlockK);

    if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
    {
        return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(BK, I1));
    }
    else
    {
        // RowMajor (guaranteed by the static_assert above).
        return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, BN));
    }
}
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
|
||||
@@ -1181,14 +1209,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K);
|
||||
const auto b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
|
||||
Reference in New Issue
Block a user