mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Optimized GEMMs for MX FP4/8 (#2294)
Adds V3 GEMM pipeline for MX FP4 and MX FP8 Adds V3 GEMM pipeline for MX FP4 with preshuffling Adds MXFP4 GEMM tests (#2275) Adds MXFP4 GEMM examples Adds MXFP4 GEMMs to ckProfiler Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: OscarXu <huaiguxu@amd.com> Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: Ding, Yi <yi.ding@amd.com> Co-authored-by: feifei14119 <feiw@amd.com> Co-authored-by: Lin, Qun <qlin@amd.com> Co-authored-by: joye <joye@amd.com> Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
233e274077
commit
00247e3c29
@@ -260,7 +260,8 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
|
||||
{
|
||||
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
|
||||
}
|
||||
@@ -422,6 +423,240 @@ struct ThreadwiseTensorSliceTransfer_v2
|
||||
SrcCoord src_coord_;
|
||||
}; // namespace ck
|
||||
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun,
|
||||
index_t scale_gather_num,
|
||||
bool InvalidElementAsNaN = false,
|
||||
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v2_gather
|
||||
{
|
||||
static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
|
||||
(!InvalidElementAsNaN),
|
||||
"Filling invalid element as NaN is only for floating point types");
|
||||
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
|
||||
static constexpr index_t PackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_idx,
|
||||
const StaticallyIndexedArray<index_t, scale_gather_num>& scale_gather_offsets)
|
||||
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)),
|
||||
scale_gather_offsets_(scale_gather_offsets)
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc need to known at compile-time");
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
{
|
||||
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
auto adjusted_origin_idx = [&]() {
|
||||
Index idx;
|
||||
|
||||
static_for<0, nDim, 1>{}(
|
||||
[&](auto i) { idx(i) = i.value == 0 ? 0 : src_slice_origin_idx[Number<i>{}]; });
|
||||
|
||||
return idx;
|
||||
}();
|
||||
|
||||
src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
|
||||
"wrong! DstSliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(
|
||||
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value &&
|
||||
"wrong! inconsistent type");
|
||||
|
||||
// DstDesc and dst_slice_origin_idx are known at compile-time
|
||||
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
||||
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
|
||||
|
||||
// scalar per access on each dim
|
||||
// TODO: don't use lambda_scalar_per_access
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto src_scalar_step_in_vector =
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(src_scalar_per_access)>>;
|
||||
|
||||
// loop over tensor and copy
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) {
|
||||
constexpr auto current_dst_origin =
|
||||
to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, 0);
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type
|
||||
src_vector;
|
||||
|
||||
using src_vector_t =
|
||||
typename vector_type_maker<SrcData,
|
||||
SrcScalarPerVector / PackedSize>::type::type;
|
||||
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
|
||||
src_coord_);
|
||||
|
||||
// copy data from src_buf into src_vector
|
||||
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize +
|
||||
scale_gather_offsets_(gather_idx),
|
||||
is_src_valid);
|
||||
|
||||
// copy data from src_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset =
|
||||
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
|
||||
src_data_idx + i * src_scalar_step_in_vector);
|
||||
constexpr auto full_dst_offset =
|
||||
dst_desc.CalculateOffset(current_dst_origin) + dst_offset;
|
||||
|
||||
if constexpr(InvalidElementAsNaN)
|
||||
{
|
||||
dst_buf(full_dst_offset) =
|
||||
is_src_valid
|
||||
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
|
||||
: NumericLimits<DstData>::QuietNaN();
|
||||
}
|
||||
else
|
||||
{
|
||||
dst_buf(Number<full_dst_offset>{}) =
|
||||
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(idx_1d.value != num_access - 1)
|
||||
{
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
|
||||
|
||||
move_tensor_coordinate(
|
||||
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n",
|
||||
// blockIdx.y,
|
||||
// threadIdx.x,
|
||||
// dst_buf(Number<0>{}));
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_step =
|
||||
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetSrcCoordinateResetStep()
|
||||
{
|
||||
constexpr auto src_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(src_scalar_per_access)>>;
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
if constexpr(num_access == 0)
|
||||
{
|
||||
return typename SpaceFillingCurve::Index{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto reset_step =
|
||||
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
|
||||
|
||||
return reset_step;
|
||||
}
|
||||
}
|
||||
|
||||
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx)
|
||||
{
|
||||
// if src coord was not reset by Run(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
|
||||
template <typename SrcMoveSliceWindowStepHack>
|
||||
__device__ void
|
||||
MoveSrcSliceWindow(const SrcDesc& src_desc,
|
||||
const Index& src_slice_origin_step_idx,
|
||||
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
|
||||
{
|
||||
// if src coord was not reset by RunRead(), then need to adjust the step here
|
||||
const auto adjusted_step_idx =
|
||||
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
|
||||
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
|
||||
|
||||
// is it OK to construct a new step every time?
|
||||
const auto adjusted_step = make_tensor_coordinate_step(
|
||||
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
|
||||
|
||||
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
|
||||
}
|
||||
|
||||
private:
|
||||
SrcCoord src_coord_;
|
||||
StaticallyIndexedArray<index_t, scale_gather_num> scale_gather_offsets_;
|
||||
}; // namespace ck
|
||||
|
||||
// Assume:
|
||||
// 1. src_desc and dst_desc are not known at compile-time
|
||||
// 2. SrcBuffer and DstBuffer are DynamicBuffer
|
||||
@@ -1053,10 +1288,8 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! SrcDesc and DstDesc need to known at compile-time");
|
||||
|
||||
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
|
||||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
|
||||
{
|
||||
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
|
||||
}
|
||||
@@ -1236,16 +1469,16 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
{
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector / PackedSize> dst_tmp_vector;
|
||||
|
||||
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
|
||||
dst_tmp_vector.template AsType<DstData>()(i) =
|
||||
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
|
||||
});
|
||||
|
||||
// copy data from dst_tmp_vector into dst_buf
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
|
||||
|
||||
|
||||
@@ -62,6 +62,18 @@ struct lambda_scalar_per_access_for_src_and_dst
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveNum, index_t nDim>
|
||||
struct lambda_wave_cluster_dimension
|
||||
{
|
||||
__host__ __device__ constexpr auto operator()(index_t i) const
|
||||
{
|
||||
if((nDim - i) == 3)
|
||||
return WaveNum;
|
||||
else
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -90,7 +90,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
src_element_op_(src_element_op),
|
||||
dst_element_op_(dst_element_op)
|
||||
{
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
if constexpr((packed_size_v<SrcData>) > 1)
|
||||
{
|
||||
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
|
||||
"SrcData != DstData");
|
||||
@@ -99,7 +99,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
|
||||
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
|
||||
|
||||
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
|
||||
static_assert(SrcVectorDim == DstVectorDim,
|
||||
"Packed data type does not support transpose");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -444,6 +445,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
{
|
||||
static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
|
||||
"in-register transpose is not supported for pk_i4_t");
|
||||
static_assert(!is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>,
|
||||
"in-register transpose is not supported for f4x2_pk_t");
|
||||
// each transpose does
|
||||
// DstScalarPerVector # of src vectors in src_thread_scratch_
|
||||
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -96,7 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
dst_element_op_(dst_element_op),
|
||||
gather_offsets_(gather_offsets)
|
||||
{
|
||||
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
|
||||
if constexpr((packed_size_v<SrcData>) > 1)
|
||||
{
|
||||
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
|
||||
"SrcData != DstData");
|
||||
@@ -105,7 +105,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
|
||||
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
|
||||
|
||||
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
|
||||
static_assert(SrcVectorDim == DstVectorDim,
|
||||
"Packed data type does not support transpose");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,7 +223,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
auto gather_offset =
|
||||
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
|
||||
|
||||
const IndexType ld_offset = src_coord_.GetOffset() + gather_offset;
|
||||
const IndexType ld_offset = src_coord_.GetOffset() / PackedSize + gather_offset;
|
||||
src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<bool>(src_data_idx_seq, true);
|
||||
|
||||
|
||||
@@ -410,8 +410,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
// dst_coords_[i]);
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
|
||||
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
|
||||
|
||||
Reference in New Issue
Block a user