[rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)

[CK_TILE] MX GEMM non-preshuffled RCR layout

## Motivation

Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled
layouts using async pipeline.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Sami Remes
2026-03-10 20:12:43 +00:00
committed by assistant-librarian[bot]
parent b8def2c724
commit 8f27f65d44
40 changed files with 2729 additions and 43 deletions

View File

@@ -249,6 +249,113 @@ struct BlockGemmARegBRegCRegV1
});
}
// C += A * B with MX scaling
// ScaleATensor: [MIterPerWarp, KIterPerWarp] -> int32_t
// ScaleBTensor: [NIterPerWarp, KIterPerWarp] -> int32_t
template <typename CBlockTensor,
typename ABlockTensor,
typename BBlockTensor,
typename ScaleATensor,
typename ScaleBTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor,
const ScaleATensor& scale_a_tensor,
const ScaleBTensor& scale_b_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// check ABC-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"A distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop with MX scaling:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// get A scale for this M-K tile using get_y_sliced_thread_data
auto scale_a_slice = scale_a_tensor.get_y_sliced_thread_data(
sequence<kIter, mIter, 0>{}, sequence<1, 1, 1>{});
const auto a_scale_e8m0 = scale_a_slice[number<0>{}];
const int32_t a_scale = static_cast<int32_t>(a_scale_e8m0.get());
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// get B scale for this N-K tile using get_y_sliced_thread_data
auto scale_b_slice = scale_b_tensor.get_y_sliced_thread_data(
sequence<kIter, nIter, 0>{}, sequence<1, 1, 1>{});
const auto b_scale_e8m0 = scale_b_slice[number<0>{}];
const int32_t b_scale = static_cast<int32_t>(b_scale_e8m0.get());
// read C warp tensor from C block tensor
using c_iter_idx = std::
conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM with MX scaling
// Cast e8m0_t to int32_t, use OpSel=0 (least significant byte)
constexpr index_t kOpSel = 0; // Always use OpSel=0
WarpGemm{}.template operator()<kOpSel, kOpSel>(
c_warp_tensor, a_warp_tensor, b_warp_tensor, a_scale, b_scale);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;

View File

@@ -141,8 +141,11 @@ struct GemmPipelineAgBgCrImplBase
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size(), 16);
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
constexpr index_t a_lds_block_space_size =
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size() / APackedSize;
constexpr index_t a_lds_block_space_size_aligned =
integer_least_multiple(a_lds_block_space_size, 16);
// B tile in LDS
OverrideBDataType* __restrict__ p_b_lds = static_cast<OverrideBDataType*>(

View File

@@ -89,6 +89,8 @@ struct BaseGemmPipelineAgBgCrCompAsync
"Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported");
#endif
}
CK_TILE_HOST static constexpr auto GetName() { return "COMPUTE_ASYNC"; }
};
/**

View File

@@ -110,7 +110,7 @@ struct GemmPipelineProblemBase
}
else
{
return VectorLoadSize / sizeof(ADataType);
return PackedSize * VectorLoadSize / sizeof(ADataType);
}
}

View File

@@ -536,14 +536,8 @@ struct UniversalGemmBasePolicy
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
// Assume DataType is even!
if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 &&
PackedSize == 2)
{
return (PackedSize * 32 / sizeof(DataType));
}
else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
{
return (PackedSize * 16 / sizeof(DataType));
}
@@ -861,30 +855,32 @@ struct UniversalGemmBasePolicy
}
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSizeA()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
constexpr index_t smem_size_a = integer_least_multiple(
a_lds_block_desc.get_element_space_size() * sizeof(ADataType), 16);
a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
constexpr index_t smem_size_b = integer_least_multiple(
b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16);
b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16);
return smem_size_b;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();

View File

@@ -1599,6 +1599,9 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// To get unity scale: 2^(kDefaultScale - 127) = 1.0
static constexpr index_t kDefaultScale = 0x7F7F7F7F;
// c_vec += a_vec * b_vec
template <index_t opselA, index_t opselB, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
@@ -1669,13 +1672,13 @@ struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
operator()<0, 0>(c_vec, a_vec, 0, b_vec, 0);
operator()<0, 0>(c_vec, a_vec, kDefaultScale, b_vec, kDefaultScale);
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
return operator()<0, 0>(a_vec, 0, b_vec, 0);
return operator()<0, 0>(a_vec, kDefaultScale, b_vec, kDefaultScale);
}
};