Ck int4 moe develop (#1949)

* Add Gemm fp8xint4 example and kernel, function pass.

* Init Gemm_fp8xint4 Bpreshuffle

* Added gemm_fp8xint4_Bpreshuffle files, function not checked yet

* General fix.

* fp8xint4 bpreshuffle function pass

* fix.

* init b preshuffle dequant in VGPR.

* fix bug, function pass.

* move b thread dequant copy to blockwise.

* fix bug, function now passes.

* modified the tile size to 256, 128x128x128.

* fixed a bug.

* Initial int4 moe, compile pass, function not check.

* fix bug in moe_gemm1.cpp, now function pass.

* test expert = 8 and function pass.

* Added moe_pk_i4_gemm2, function pass.

* Added b preshuffle pipeline v3 support.

* fixed merge issue. fp8xint4 and fp8xint4_bpreshuffle function pass.

* Split the blockwise pipeline for fp8xint4.

* commit missing files

* opt gemm2 to 2x2 wave

* fix swizzle = false

* update int4 moe with latest input changes.

* update tile size.

* enable pipeline v3.

* fix nswizzle = true

* commit a version for compiler debug.

* Updated transfer_v3r1_gather to support pk_i4_t type.

* for int4 moe2 for type_convert support.

* remove some values between mfma instructions.

* fix int4 moe

* Updated transfer_v3r1_gather to support pk_i4_t type.

* i4 support lds multiple shuffle

* fixed int4 moe tflops calculation.

* Modified CshuffleCShuffleMXdlPerWavePerShuffle to 1 to suit C multiple shuffle

* updated gemm2.

* change int4 moe example names

* fix and format code.

* format.

* format codes.

* update fp8xint4 example tile size.

* add <unordered_map> header

* fixed.

* format.

* Added conditional compilation for int4 -> fp8 conversion kernels

---------

Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
Mingtao Gu
2025-03-10 11:16:44 +08:00
committed by GitHub
parent c954bd0cfa
commit 0db7c8f0b2
19 changed files with 6018 additions and 83 deletions

View File

@@ -224,6 +224,13 @@ struct ThreadwiseTensorSliceTransfer_v2
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
else
return 1;
}();
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
const Index& src_slice_origin_idx)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
@@ -232,6 +239,11 @@ struct ThreadwiseTensorSliceTransfer_v2
"wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -276,10 +288,10 @@ struct ThreadwiseTensorSliceTransfer_v2
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type src_vector;
using src_vector_t =
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type::type;
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
const bool is_src_valid =
@@ -287,10 +299,11 @@ struct ThreadwiseTensorSliceTransfer_v2
// copy data from src_buf into src_vector
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize,
is_src_valid);
// copy data from src_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
constexpr index_t dst_offset =
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector);
@@ -1465,6 +1478,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
using Index = MultiIndex<nDim>;
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
else
return 1;
}();
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic(
const ElementwiseOperation& element_op)
: element_op_{element_op}
@@ -1485,7 +1505,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf)
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
@@ -1519,26 +1539,71 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
{
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, DstScalarPerVector / PackedSize>::type
src_tmp_vector;
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector / PackedSize, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
DstData v;
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
});
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;
// apply type convert
dst_buf(Number<dst_offset>{}) = v;
constexpr index_t pack_size = 8;
static_assert(DstScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
static_for<0, DstScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
});
});
}
else
{
static_for<0, num_access, 1>{}([&](auto idx_1d) {
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
DstData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_buf(Number<dst_offset>{}) = v;
});
});
}
}
ElementwiseOperation element_op_;

View File

@@ -31,8 +31,8 @@ template <typename SliceLengths,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector,
index_t SrcScalarPerVector_,
index_t DstScalarPerVector_,
index_t SrcScalarStrideInVector,
index_t DstScalarStrideInVector,
bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
@@ -54,7 +54,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I10 = Number<10>{};
static constexpr auto I12 = Number<12>{};
static constexpr auto I13 = Number<13>{};
static constexpr auto I14 = Number<14>{};
static constexpr auto I16 = Number<16>{};
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
else
return 1;
}();
static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
static constexpr index_t gather_num = SliceLengths{}.At(Number<GatherDim>{});
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_gather(
@@ -71,6 +95,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
dst_element_op_(dst_element_op),
gather_offsets_(gather_offsets)
{
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
static_assert(
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -107,10 +142,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
@@ -212,17 +248,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
else if constexpr(is_detected<is_pack4_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
else if constexpr(is_detected<is_pack2_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
return 1;
else
{
return 1;
}
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
@@ -306,7 +347,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// OOB Check
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -377,6 +418,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
"in-register transpose is not supported for pk_i4_t");
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
@@ -437,7 +480,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
}
else
{
static_ford<SliceLengths>{}([&](auto idx) {
constexpr auto packed_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
});
}
@@ -465,7 +513,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// src scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -559,7 +607,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// copy data from dst_vector_container to dst_buf
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
dst_coord_.GetOffset() / PackedSize,
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
@@ -613,7 +661,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -672,7 +720,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -757,7 +805,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -806,7 +854,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -817,7 +865,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
{
// 1st stage of transforms
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;