mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)
[CK_TILE] MX GEMM non-preshuffled RCR layout ## Motivation Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled layouts using async pipeline. ## Technical Details <!-- Explain the changes along with any relevant GitHub links. --> ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b8def2c724
commit
8f27f65d44
@@ -141,8 +141,11 @@ struct GemmPipelineAgBgCrImplBase
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// TODO: LDS alignment should come from Policy!
|
||||
constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
|
||||
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size(), 16);
|
||||
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
|
||||
constexpr index_t a_lds_block_space_size =
|
||||
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size() / APackedSize;
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_least_multiple(a_lds_block_space_size, 16);
|
||||
|
||||
// B tile in LDS
|
||||
OverrideBDataType* __restrict__ p_b_lds = static_cast<OverrideBDataType*>(
|
||||
|
||||
@@ -89,6 +89,8 @@ struct BaseGemmPipelineAgBgCrCompAsync
|
||||
"Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported");
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GetName() { return "COMPUTE_ASYNC"; }
|
||||
};
|
||||
|
||||
/**
|
||||
|
||||
@@ -110,7 +110,7 @@ struct GemmPipelineProblemBase
|
||||
}
|
||||
else
|
||||
{
|
||||
return VectorLoadSize / sizeof(ADataType);
|
||||
return PackedSize * VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -536,14 +536,8 @@ struct UniversalGemmBasePolicy
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
// Assume DataType is even!
|
||||
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
|
||||
PackedSize == 2)
|
||||
{
|
||||
return (PackedSize * 32 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
|
||||
if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 16 / sizeof(DataType));
|
||||
}
|
||||
@@ -861,30 +855,32 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_a = integer_least_multiple(
|
||||
a_lds_block_desc.get_element_space_size() * sizeof(ADataType), 16);
|
||||
a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16);
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16);
|
||||
b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
|
||||
Reference in New Issue
Block a user