mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add other layouts for FP8 block scaled gemm (#2665)
* Start adding other layouts for gemm_ab_scale * Add some instances * Create tensor descriptors for A/B scales depending on A/B layout * Fix formatting * Revert some comments * Revert commented instances in CMakeLists.txt * Add some more instances for col-row gemm * enable more row,row instances * Use occupancy=1 for col,row layout to avoid spills
This commit is contained in:
@@ -402,6 +402,34 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K)
|
||||
{
|
||||
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
|
||||
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(BK, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, BM));
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K)
|
||||
{
|
||||
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
|
||||
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
|
||||
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(BK, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, BN));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
|
||||
@@ -1181,14 +1209,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K);
|
||||
const auto b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K);
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
|
||||
Reference in New Issue
Block a user