[CK_TILE] Enable MXFP6 for MX GEMM op (#5095)

## Motivation

Add support for MXFP6 in the MX GEMM op in CK-Tile.

Depends on https://github.com/ROCm/rocm-libraries/pull/4594

## Technical Details

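<!-- Explain the changes along with any relevant GitHub links. -->

- `GemmPipelineAgBgCrImplBase`: the A LDS block space size is now computed with
  `lds_padded_sizeof<OverrideADataType>()` instead of `sizeof(OverrideADataType)`.
- `UniversalGemmBasePolicy`: the `VecElems` value for A and B is now
  `Problem::VectorLoadSize / sizeof(T)` multiplied by `numeric_traits<T>::PackedSize`,
  and the result is selected with `ck_tile::min(KPack, VecElems)` instead of a
  hand-written ternary.
- `UniversalGemmBasePolicy`: `smem_size_a` and `smem_size_b` now use
  `lds_padded_sizeof<ADataType>()` / `lds_padded_sizeof<BDataType>()` instead of
  `sizeof(...)` before being rounded up to a multiple of 16.
- `MXGemmPipelineAgBgCrCompAsync`: the `SmemSizeA` / `SmemSizeB` `static_assert`
  checks are updated to use `lds_padded_sizeof<...>()` so they match the policy's
  size computation.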

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Sami Remes
2026-03-20 03:07:47 +02:00
committed by GitHub
parent a61238b69f
commit a699df9fdc
13 changed files with 160 additions and 31 deletions

View File

@@ -146,9 +146,10 @@ struct GemmPipelineAgBgCrImplBase
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
constexpr index_t a_lds_block_space_size =
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size() / APackedSize;
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
constexpr index_t a_lds_block_space_size = lds_padded_sizeof<OverrideADataType>() *
a_lds_block_desc.get_element_space_size() /
APackedSize;
constexpr index_t a_lds_block_space_size_aligned =
integer_least_multiple(a_lds_block_space_size, 16);

View File

@@ -834,9 +834,10 @@ struct UniversalGemmBasePolicy
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A)) *
numeric_traits<A>::PackedSize;
return (KPack < VecElems) ? KPack : VecElems;
return ck_tile::min(KPack, VecElems);
}
template <typename Problem>
@@ -846,9 +847,10 @@ struct UniversalGemmBasePolicy
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B)) *
numeric_traits<B>::PackedSize;
return (KPack < VecElems) ? KPack : VecElems;
return ck_tile::min(KPack, VecElems);
}
template <typename Problem>
@@ -857,8 +859,10 @@ struct UniversalGemmBasePolicy
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
constexpr index_t smem_size_a = integer_least_multiple(
a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16);
constexpr index_t smem_size_a =
integer_least_multiple(a_lds_block_desc.get_element_space_size() *
lds_padded_sizeof<ADataType>() / APackedSize,
16);
return smem_size_a;
}
@@ -871,8 +875,10 @@ struct UniversalGemmBasePolicy
typename Problem::BDataType>;
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t smem_size_b = integer_least_multiple(
b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16);
constexpr index_t smem_size_b =
integer_least_multiple(b_lds_block_desc.get_element_space_size() *
lds_padded_sizeof<BDataType>() / BPackedSize,
16);
return smem_size_b;
}

View File

@@ -442,10 +442,12 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
MWarp / BlockSize,
"BLdsTile size is wrong!");
static_assert(Policy::template GetSmemSizeA<Problem>() ==
MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize),
MPerBlock *
(KPerBlock * lds_padded_sizeof<ADataType>() / APackedSize),
"SmemSizeA size is wrong!");
static_assert(Policy::template GetSmemSizeB<Problem>() ==
(KPerBlock * sizeof(BDataType) / BPackedSize) * NPerBlock,
(KPerBlock * lds_padded_sizeof<BDataType>() / BPackedSize) *
NPerBlock,
"SmemSizeB size is wrong!");
////////////// MX Scale register tiles (ping-pong buffers) /////////////////