[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)

[CK_TILE] MX GEMM non-preshuffled RCR layout

## Motivation

Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled
layouts using async pipeline.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Sami Remes
2026-03-10 20:12:43 +00:00
committed by assistant-librarian[bot]
parent b8def2c724
commit 8f27f65d44
40 changed files with 2729 additions and 43 deletions

View File

@@ -141,8 +141,11 @@ struct GemmPipelineAgBgCrImplBase
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size(), 16);
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
constexpr index_t a_lds_block_space_size =
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size() / APackedSize;
constexpr index_t a_lds_block_space_size_aligned =
integer_least_multiple(a_lds_block_space_size, 16);
// B tile in LDS
OverrideBDataType* __restrict__ p_b_lds = static_cast<OverrideBDataType*>(

View File

@@ -89,6 +89,8 @@ struct BaseGemmPipelineAgBgCrCompAsync
"Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported");
#endif
}
CK_TILE_HOST static constexpr auto GetName() { return "COMPUTE_ASYNC"; }
};
/**

View File

@@ -110,7 +110,7 @@ struct GemmPipelineProblemBase
}
else
{
return VectorLoadSize / sizeof(ADataType);
return PackedSize * VectorLoadSize / sizeof(ADataType);
}
}

View File

@@ -536,14 +536,8 @@ struct UniversalGemmBasePolicy
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
// Assume DataType is even!
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
PackedSize == 2)
{
return (PackedSize * 32 / sizeof(DataType));
}
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
{
return (PackedSize * 16 / sizeof(DataType));
}
@@ -861,30 +855,32 @@ struct UniversalGemmBasePolicy
}
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
constexpr index_t smem_size_a = integer_least_multiple(
a_lds_block_desc.get_element_space_size() * sizeof(ADataType), 16);
a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t smem_size_b = integer_least_multiple(
b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16);
b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16);
return smem_size_b;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();