Optimized GEMMs for MX FP4/8 (#2294)

Adds V3 GEMM pipeline for MXFP4 and MXFP8
Adds V3 GEMM pipeline for MXFP4 with preshuffling
Adds MXFP4 GEMM tests (#2275)
Adds MXFP4 GEMM examples
Adds MXFP4 GEMMs to ckProfiler

Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
Author: Andriy Roshchenko
Date: 2025-06-05 13:54:15 -06:00
Committed by: GitHub
Parent: 233e274077
Commit: 00247e3c29
83 changed files with 8193 additions and 2165 deletions
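For orientation: per the OCP Microscaling (MX) spec, MXFP4 stores blocks of 32 FP4 (E2M1) elements sharing one 8-bit power-of-two scale (E8M0), with two FP4 values packed per byte. That two-per-byte packing is what the new `f4x2_pk_t` handling in the transfer code below accounts for. A minimal decode sketch — the struct layout and nibble order are illustrative assumptions, not the CK types:

```cpp
#include <cstdint>
#include <cmath>

// Decode one FP4 (E2M1) nibble: 1 sign bit, 2 exponent bits (bias 1), 1
// mantissa bit. Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
inline float decode_e2m1(uint8_t nibble)
{
    const float sign = (nibble & 0x8) ? -1.f : 1.f;
    const int   exp  = (nibble >> 1) & 0x3;
    const int   man  = nibble & 0x1;
    if(exp == 0)                      // subnormal: man * 2^-1
        return sign * man * 0.5f;
    return sign * std::ldexp(1.f + 0.5f * man, exp - 1);
}

// One MXFP4 block: 32 E2M1 elements packed two-per-byte plus a shared E8M0
// scale (2^(e-127)). Hypothetical layout for illustration, not the CK type.
struct mxfp4_block
{
    uint8_t packed[16]; // element 2k in the low nibble, 2k+1 in the high nibble
    uint8_t scale_e8m0;
};

inline float mxfp4_get(const mxfp4_block& b, int i)
{
    const uint8_t byte   = b.packed[i / 2];
    const uint8_t nibble = (i % 2) ? (byte >> 4) : (byte & 0xF);
    return decode_e2m1(nibble) * std::ldexp(1.f, int(b.scale_e8m0) - 127);
}
```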


@@ -260,7 +260,8 @@ struct ThreadwiseTensorSliceTransfer_v2
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
- if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+ if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
+              is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
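The `SrcScalarPerVector % PackedSize` checks exist because packed FP4 data is addressed in bytes that each hold two elements; a vector transfer of an odd element count would split a byte. A toy restatement of the constraint, with assumed values:

```cpp
// With a 4-bit packed type, PackedSize == 2: each byte of storage carries two
// elements, so per-thread vector widths must be even to stay byte-aligned.
constexpr int PackedSize         = 2;
constexpr int SrcScalarPerVector = 8; // 8 FP4 elements -> 4 bytes per access
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
// SrcScalarPerVector == 1 would request half a byte and is rejected.
```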
@@ -422,6 +423,240 @@ struct ThreadwiseTensorSliceTransfer_v2
SrcCoord src_coord_;
}; // struct ThreadwiseTensorSliceTransfer_v2
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun,
index_t scale_gather_num,
bool InvalidElementAsNaN = false,
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2_gather
{
static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
(!InvalidElementAsNaN),
"Filling invalid element as NaN is only for floating point types");
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
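// Elements per addressable storage unit: 2 for 4-bit packed data, else 1.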
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
else
return 1;
}();
__device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather(
const SrcDesc& src_desc,
const Index& src_slice_origin_idx,
const StaticallyIndexedArray<index_t, scale_gather_num>& scale_gather_offsets)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)),
scale_gather_offsets_(scale_gather_offsets)
{
static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! DstDesc needs to be known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
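// Dimension 0 of the origin is forced to 0 below; the per-pass displacement
// along it is supplied by scale_gather_offsets_ during Run() instead.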
auto adjusted_origin_idx = [&]() {
Index idx;
static_for<0, nDim, 1>{}(
[&](auto i) { idx(i) = i.value == 0 ? 0 : src_slice_origin_idx[Number<i>{}]; });
return idx;
}();
src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
}
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf)
{
static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! DstDesc needs to be known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
"wrong! DstSliceOrigin needs to be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value,
"wrong! inconsistent type");
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
// loop over tensor and copy
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
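// Replay the whole slice copy once per gather index: each pass reads src
// displaced by scale_gather_offsets_[gather_idx] and writes the dst row
// displaced by (gather_idx, 0).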
static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) {
constexpr auto current_dst_origin =
to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, 0);
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type
src_vector;
using src_vector_t =
typename vector_type_maker<SrcData,
SrcScalarPerVector / PackedSize>::type::type;
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
src_coord_);
// copy data from src_buf into src_vector
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize +
scale_gather_offsets_(gather_idx),
is_src_valid);
// copy data from src_vector into dst_buf
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
constexpr index_t dst_offset =
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
src_data_idx + i * src_scalar_step_in_vector);
constexpr auto full_dst_offset =
dst_desc.CalculateOffset(current_dst_origin) + dst_offset;
if constexpr(InvalidElementAsNaN)
{
dst_buf(Number<full_dst_offset>{}) =
is_src_valid
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
: NumericLimits<DstData>::QuietNaN();
}
else
{
dst_buf(Number<full_dst_offset>{}) =
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
}
});
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
}
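// Example: if the curve visits offsets 0, 4, 8, 12 (illustrative numbers),
// the reset step is the single move last -> first, i.e. -12, so Run() can
// return the coordinate to where it started.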
// src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if the src coordinate was not reset by Run(), the step needs to be adjusted here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
// if the src coordinate was not reset by Run(), the step needs to be adjusted here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
private:
SrcCoord src_coord_;
StaticallyIndexedArray<index_t, scale_gather_num> scale_gather_offsets_;
}; // struct ThreadwiseTensorSliceTransfer_v2_gather
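The `_gather` variant above replays one slice copy per `scale_gather_num`, displacing each source read by a runtime offset and each destination row by the gather index. A self-contained sketch of that offset arithmetic — plain loops with made-up sizes, not the real template machinery:

```cpp
#include <array>
#include <cstdio>

int main()
{
    constexpr int scale_gather_num  = 4; // passes, e.g. one per gathered row
    constexpr int elements_per_pass = 8; // slice length along the fast dim
    const std::array<int, scale_gather_num> scale_gather_offsets{0, 128, 256, 384};

    for(int g = 0; g < scale_gather_num; ++g)
        for(int e = 0; e < elements_per_pass; ++e)
        {
            // src: the space-filling-curve walk shifted by the per-pass offset
            const int src_offset = scale_gather_offsets[g] + e;
            // dst: passes land in consecutive rows, displaced by (gather_idx, 0)
            const int dst_offset = g * elements_per_pass + e;
            std::printf("pass %d: src %3d -> dst %2d\n", g, src_offset, dst_offset);
        }
}
```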
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
@@ -1053,10 +1288,8 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to be known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
- if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+ if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
+              is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
@@ -1236,16 +1469,16 @@ struct ThreadwiseTensorSliceTransfer_v4
{
// copy data from src_tmp_vector to dst_tmp_vector (cast from SrcData to DstData)
- vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
+ vector_type_maker_t<DstData, SrcScalarPerVector / PackedSize> dst_tmp_vector;
// TODO: if SrcData and DstData are vector types, then static_cast may not compile
- static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+ static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
- static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+ static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);


@@ -62,6 +62,18 @@ struct lambda_scalar_per_access_for_src_and_dst
}
};
template <index_t WaveNum, index_t nDim>
struct lambda_wave_cluster_dimension
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if((nDim - i) == 3)
return WaveNum;
else
return 1;
}
};
} // namespace detail
} // namespace ck
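Fed to `generate_sequence`, this functor builds a wave-cluster shape that puts all `WaveNum` waves on the third-from-last dimension and 1 everywhere else. A plain constexpr restatement for nDim == 4, outside the CK Sequence machinery:

```cpp
constexpr int WaveNum = 4;
constexpr int nDim    = 4;

constexpr int cluster_len(int i) { return (nDim - i) == 3 ? WaveNum : 1; }

// Sequence for nDim == 4: {1, WaveNum, 1, 1} -- every wave clusters on the
// third-from-last dimension, matching lambda_wave_cluster_dimension.
static_assert(cluster_len(0) == 1 && cluster_len(1) == WaveNum &&
              cluster_len(2) == 1 && cluster_len(3) == 1, "");
```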


@@ -90,7 +90,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
- if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+ if constexpr((packed_size_v<SrcData>) > 1)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
@@ -99,7 +99,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
- static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
+ static_assert(SrcVectorDim == DstVectorDim,
+               "Packed data type does not support transpose");
}
}
@@ -444,6 +445,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{
static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
"in-register transpose is not supported for pk_i4_t");
+ static_assert(!is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>,
+               "in-register transpose is not supported for f4x2_pk_t");
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
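The new `f4x2_pk_t` assert mirrors the `pk_i4_t` one: an in-register transpose would have to split each byte's two nibbles across different destination bytes, i.e. an unpack/repack rather than a byte-wise shuffle. A sketch of why, for a 2x2 FP4 tile (hypothetical helper, not CK code):

```cpp
#include <cstdint>

// Transposing a 2x2 tile of FP4 values stored f4x2-style: each output byte
// must combine nibbles from two *different* input bytes, which the byte-wise
// in-register transpose path cannot express.
inline void transpose2x2_fp4(const uint8_t in[2], uint8_t out[2])
{
    // in[r] holds row r: element (r,0) in the low nibble, (r,1) in the high.
    out[0] = uint8_t((in[0] & 0x0F) | ((in[1] & 0x0F) << 4)); // column 0
    out[1] = uint8_t(((in[0] >> 4) & 0x0F) | (in[1] & 0xF0)); // column 1
}
```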


@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
- // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -96,7 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
dst_element_op_(dst_element_op),
gather_offsets_(gather_offsets)
{
- if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+ if constexpr((packed_size_v<SrcData>) > 1)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
@@ -105,7 +105,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
- static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
+ static_assert(SrcVectorDim == DstVectorDim,
+               "Packed data type does not support transpose");
}
}
@@ -222,7 +223,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
auto gather_offset =
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
- const IndexType ld_offset = src_coord_.GetOffset() + gather_offset;
+ const IndexType ld_offset = src_coord_.GetOffset() / PackedSize + gather_offset;
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, true);
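The added `/ PackedSize` converts the descriptor's element-indexed offset into a storage-unit offset: with two FP4 elements per byte, element n lives in storage unit n/2. A quick sanity check with assumed numbers:

```cpp
// Descriptor offsets count FP4 elements; the buffer is addressed in f4x2
// units, each holding PackedSize == 2 elements.
constexpr int PackedSize     = 2;
constexpr int element_offset = 70;                          // 70th FP4 element
constexpr int ld_offset      = element_offset / PackedSize; // storage unit 35
static_assert(ld_offset == 35, "element -> storage offset");
```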


@@ -410,8 +410,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
- // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
- // dst_coords_[i]);
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(