mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[rocm-libraries] ROCm/rocm-libraries#5095 (commit 7e55766)
[CK_TILE] Enable MXFP6 for MX GEMM op

## Motivation

Add support for MXFP6 in the MX GEMM op in CK-Tile. Depends on https://github.com/ROCm/rocm-libraries/pull/4594

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
a5d0200ccf
commit
d7c761e060
@@ -146,9 +146,10 @@ struct GemmPipelineAgBgCrImplBase
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// TODO: LDS alignment should come from Policy!
|
||||
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
|
||||
constexpr index_t a_lds_block_space_size =
|
||||
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size() / APackedSize;
|
||||
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
|
||||
constexpr index_t a_lds_block_space_size = lds_padded_sizeof<OverrideADataType>() *
|
||||
a_lds_block_desc.get_element_space_size() /
|
||||
APackedSize;
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_least_multiple(a_lds_block_space_size, 16);
|
||||
|
||||
|
||||
@@ -837,9 +837,10 @@ struct UniversalGemmBasePolicy
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A)) *
|
||||
numeric_traits<A>::PackedSize;
|
||||
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
return ck_tile::min(KPack, VecElems);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -849,9 +850,10 @@ struct UniversalGemmBasePolicy
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B)) *
|
||||
numeric_traits<B>::PackedSize;
|
||||
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
return ck_tile::min(KPack, VecElems);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -860,8 +862,10 @@ struct UniversalGemmBasePolicy
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_a = integer_least_multiple(
|
||||
a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16);
|
||||
constexpr index_t smem_size_a =
|
||||
integer_least_multiple(a_lds_block_desc.get_element_space_size() *
|
||||
lds_padded_sizeof<ADataType>() / APackedSize,
|
||||
16);
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
@@ -874,8 +878,10 @@ struct UniversalGemmBasePolicy
|
||||
typename Problem::BDataType>;
|
||||
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16);
|
||||
constexpr index_t smem_size_b =
|
||||
integer_least_multiple(b_lds_block_desc.get_element_space_size() *
|
||||
lds_padded_sizeof<BDataType>() / BPackedSize,
|
||||
16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user