Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-02 04:31:25 +00:00
Ck int4 moe develop (#1949)
* Add Gemm fp8xint4 example and kernel, function pass.
* Init Gemm_fp8xint4 Bpreshuffle
* Added gemm_fp8xint4_Bpreshuffle files, function not checked yet
* General fix.
* fp8xint4 bpreshuffle function pass
* fix.
* init b preshuffle dequant in VGPR.
* fix bug, function pass.
* move b thread dequant copy to blockwise.
* fix bug, function now passes.
* modified the tile size to 256, 128x128x128.
* fixed a bug.
* Initial int4 moe, compile pass, function not check.
* fix bug in moe_gemm1.cpp, now function pass.
* test expert = 8 and function pass.
* Added moe_pk_i4_gemm2, function pass.
* Added b preshuffle pipeline v3 support.
* fixed merge issue. fp8xint4 and fp8xint4_bpreshuffle function pass.
* Split the blockwise pipeline for fp8xint4.
* commit missing files
* opt gemm2 to 2x2 wave
* fix swizzle = false
* update int4 moe with latest input changes.
* update tile size.
* enable pipeline v3.
* fix nswizzle = true
* commit a version for compiler debug.
* Updated transfer_v3r1_gather to support pk_i4_t type.
* for int4 moe2 for type_convert support.
* remove some values between mfma instructions.
* fix int4 moe
* Updated transfer_v3r1_gather to support pk_i4_t type.
* i4 support lds multiple shuffle
* fixed int4 moe tflops calculation.
* Modified CshuffleCShuffleMXdlPerWavePerShuffle to 1 to suit C multiple shuffle
* updated gemm2.
* change int4 moe example names
* fix and format code.
* format.
* format codes.
* update fp8xint4 example tile size.
* add <unordered_map> header
* fixed.
* format.
* Added conditional compilation for int4 -> fp8 conversion kernels

---------

Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>
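Reviewer note: the changes below all follow from one fact about pk_i4_t: a single storage element packs two 4-bit values, so logical element counts and linear offsets must be divided by a compile-time PackedSize of 2 wherever the packed type flows through. A minimal standalone sketch of that packing model (the pk_i4_t stand-in and unpack helper here are illustrative, not the CK definitions):

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for ck::pk_i4_t: one byte carrying two 4-bit lanes.
struct pk_i4_t
{
    uint8_t data;
};

constexpr int PackedSize = 2; // two int4 values per storage element

// Sign-extend one 4-bit lane: lane 0 is the low nibble, lane 1 the high one.
int unpack(pk_i4_t v, int lane)
{
    int nibble = (lane == 0) ? (v.data & 0xF) : (v.data >> 4);
    return nibble >= 8 ? nibble - 16 : nibble;
}

int main()
{
    constexpr int logical_n = 8;                      // logical int4 elements
    constexpr int storage_n = logical_n / PackedSize; // packed bytes actually stored
    pk_i4_t buf[storage_n] = {{0x21}, {0x43}, {0x65}, {0xF7}}; // 1,2  3,4  5,6  7,-1

    // A logical index i maps to storage element i / PackedSize and lane
    // i % PackedSize -- the same mapping behind the GetOffset() / PackedSize
    // edits in the hunks below.
    for(int i = 0; i < logical_n; ++i)
        printf("%d ", unpack(buf[i / PackedSize], i % PackedSize));
    printf("\n");
}

With that mapping in mind, the repeated `/ PackedSize` edits below are one correction applied uniformly to vector widths, loop trip counts, and buffer offsets.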
@@ -224,6 +224,13 @@ struct ThreadwiseTensorSliceTransfer_v2
     using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
 
+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
     __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
                                                           const Index& src_slice_origin_idx)
         : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
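The PackedSize member added above is an immediately-invoked constexpr lambda, so the type dispatch is resolved per instantiation and the result is a true compile-time constant, usable in static_asserts and array bounds. A compilable sketch of the idiom with stand-in types (TransferLike and the bare pk_i4_t tag are hypothetical):

#include <type_traits>

struct pk_i4_t; // tag type standing in for ck::pk_i4_t

template <typename SrcData>
struct TransferLike
{
    // Immediately-invoked constexpr lambda: the if constexpr branch is
    // selected at compile time for each SrcData instantiation.
    static constexpr int PackedSize = []() {
        if constexpr(std::is_same_v<std::remove_cv_t<SrcData>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();
};

static_assert(TransferLike<pk_i4_t>::PackedSize == 2);
static_assert(TransferLike<float>::PackedSize == 1);

int main() {}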
@@ -232,6 +239,11 @@ struct ThreadwiseTensorSliceTransfer_v2
                       "wrong! SrcDesc need to known at compile-time");
         static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                       "wrong! Not divisible");
+
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+        {
+            static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
+        }
     }
 
     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -276,10 +288,10 @@ struct ThreadwiseTensorSliceTransfer_v2
         constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
 
         static_for<0, num_access, 1>{}([&](auto idx_1d) {
-            typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
+            typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type src_vector;
 
             using src_vector_t =
-                typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
+                typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type::type;
             constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
 
             const bool is_src_valid =
@@ -287,10 +299,11 @@ struct ThreadwiseTensorSliceTransfer_v2
 
             // copy data from src_buf into src_vector
             src_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
+                src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize,
+                                                   is_src_valid);
 
             // copy data from src_vector into dst_buf
-            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+            static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
                 constexpr index_t dst_offset =
                     dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
                                              i * src_scalar_step_in_vector);
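Note the two halvings in this hunk: the coordinate offset, which counts logical scalars, is divided by PackedSize before addressing the packed buffer, and the per-access loop now moves SrcScalarPerVector / PackedSize storage elements. A small host-side sketch of that address arithmetic (buffer contents and names are illustrative):

#include <cstdint>
#include <cstdio>

int main()
{
    constexpr int SrcScalarPerVector = 8; // logical int4 scalars per access
    constexpr int PackedSize         = 2; // int4 lanes per storage byte

    // Sixteen packed bytes = 32 logical int4 values.
    uint8_t src[16];
    for(int i = 0; i < 16; ++i)
        src[i] = static_cast<uint8_t>(i);

    // A logical element offset of 8 becomes storage offset 8 / PackedSize = 4,
    // mirroring src_coord_.GetOffset() / PackedSize above.
    constexpr int logical_offset = 8;
    const uint8_t* vec = src + logical_offset / PackedSize;

    // The vector access itself also shrinks: it moves
    // SrcScalarPerVector / PackedSize storage elements, not SrcScalarPerVector.
    for(int i = 0; i < SrcScalarPerVector / PackedSize; ++i)
        printf("storage[%d] = %u\n", i, vec[i]);
}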
@@ -1465,6 +1478,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
 
     using Index = MultiIndex<nDim>;
 
+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
     __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic(
         const ElementwiseOperation& element_op)
         : element_op_{element_op}
@@ -1485,7 +1505,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
                         const SrcBuffer& src_buf,
                         const DstDesc&,
                         const DstSliceOriginIdx&,
-                        DstBuffer& dst_buf)
+                        DstBuffer& dst_buf) const
     {
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! Desc need to known at compile-time");
@@ -1519,26 +1539,71 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
 
         constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
 
-        static_for<0, num_access, 1>{}([&](auto idx_1d) {
-            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
-
-            // copy data from src_buf into dst_vector
-            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
-                constexpr index_t src_offset = src_desc.CalculateOffset(
-                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
-
-                constexpr index_t dst_offset = dst_desc.CalculateOffset(
-                    dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
-
-                DstData v;
-
-                // apply element-wise operation
-                element_op_(v, src_buf[Number<src_offset>{}]);
-
-                // apply type convert
-                dst_buf(Number<dst_offset>{}) = v;
-            });
-        });
+        if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
+        {
+            static_for<0, num_access, 1>{}([&](auto idx_1d) {
+                typename vector_type_maker<SrcData, DstScalarPerVector / PackedSize>::type
+                    src_tmp_vector;
+
+                constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
+
+                // copy data from src_buf into dst_vector
+                static_for<0, DstScalarPerVector / PackedSize, 1>{}([&](auto i) {
+                    constexpr index_t src_offset = src_desc.CalculateOffset(
+                        src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                    src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
+                });
+
+                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
+                // DstData)
+                vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;
+
+                constexpr index_t pack_size = 8;
+
+                static_assert(DstScalarPerVector % pack_size == 0, "");
+
+                using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
+                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
+
+                static_for<0, DstScalarPerVector / pack_size, 1>{}([&](auto i) {
+                    ck::tensor_operation::element_wise::PassThroughPack8{}(
+                        dst_tmp_vector.template AsType<dst_v_t>()(i),
+                        src_tmp_vector.template AsType<src_v_t>()[i]);
+                });
+
+                // copy data from dst_tmp_vector into dst_buf
+                static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                        dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
+                });
+            });
+        }
+        else
+        {
+            static_for<0, num_access, 1>{}([&](auto idx_1d) {
+                constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
+
+                // copy data from src_buf into dst_vector
+                static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t src_offset = src_desc.CalculateOffset(
+                        src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                        dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                    DstData v;
+
+                    // apply element-wise operation
+                    element_op_(v, src_buf[Number<src_offset>{}]);
+
+                    // apply type convert
+                    dst_buf(Number<dst_offset>{}) = v;
+                });
+            });
+        }
     }
 
     ElementwiseOperation element_op_;
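The pk_i4_t branch above converts in fixed groups of eight logical lanes: each PassThroughPack8 call consumes pack_size / PackedSize = 4 packed source elements and produces 8 destination scalars. A self-contained sketch of that grouping, using float as a stand-in destination type and a hypothetical pass_through_pack8 helper:

#include <cstdint>
#include <cstdio>

// Hypothetical stand-in for PassThroughPack8: one call expands 4 packed
// bytes (8 int4 lanes) into 8 destination scalars.
void pass_through_pack8(float* dst, const uint8_t* src)
{
    for(int i = 0; i < 8; ++i)
    {
        int nibble = (i % 2 == 0) ? (src[i / 2] & 0xF) : (src[i / 2] >> 4);
        dst[i]     = static_cast<float>(nibble >= 8 ? nibble - 16 : nibble);
    }
}

int main()
{
    constexpr int DstScalarPerVector = 16;
    constexpr int PackedSize         = 2;
    constexpr int pack_size          = 8; // fixed group width, as in the diff
    static_assert(DstScalarPerVector % pack_size == 0, "");

    uint8_t src[DstScalarPerVector / PackedSize] = {
        0x21, 0x43, 0x65, 0x87, 0xA9, 0xCB, 0xED, 0x0F};
    float dst[DstScalarPerVector];

    // Same loop structure as the diff: DstScalarPerVector / pack_size calls,
    // each advancing 8 destination scalars and 4 packed source bytes.
    for(int g = 0; g < DstScalarPerVector / pack_size; ++g)
        pass_through_pack8(dst + g * pack_size, src + g * (pack_size / PackedSize));

    for(float v : dst)
        printf("%g ", v);
    printf("\n");
}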
@@ -31,8 +31,8 @@ template <typename SliceLengths,
           typename DstDimAccessOrder,
           index_t SrcVectorDim,
           index_t DstVectorDim,
-          index_t SrcScalarPerVector,
-          index_t DstScalarPerVector,
+          index_t SrcScalarPerVector_,
+          index_t DstScalarPerVector_,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
@@ -54,7 +54,31 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
     using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
     using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
 
     static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};
+    static constexpr auto I6 = Number<6>{};
+    static constexpr auto I7 = Number<7>{};
+    static constexpr auto I8 = Number<8>{};
+    static constexpr auto I10 = Number<10>{};
+    static constexpr auto I12 = Number<12>{};
+    static constexpr auto I13 = Number<13>{};
+    static constexpr auto I14 = Number<14>{};
+    static constexpr auto I16 = Number<16>{};
+
+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
+    static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
+    static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
 
     static constexpr index_t gather_num = SliceLengths{}.At(Number<GatherDim>{});
 
     __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_gather(
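The renaming above is deliberate: the underscored template parameters SrcScalarPerVector_ / DstScalarPerVector_ keep counting logical scalars (and remain what callers pass), while the derived members divide by PackedSize to count storage elements. A sketch of that split with stand-in types (GatherTransferLike is hypothetical):

#include <type_traits>

struct pk_i4_t; // tag type standing in for ck::pk_i4_t

template <typename SrcData, int SrcScalarPerVector_, int DstScalarPerVector_>
struct GatherTransferLike
{
    static constexpr int PackedSize =
        std::is_same_v<std::remove_cv_t<SrcData>, pk_i4_t> ? 2 : 1;

    // Derived widths count storage elements; the underscored template
    // parameters keep counting logical scalars.
    static constexpr int SrcScalarPerVector = SrcScalarPerVector_ / PackedSize;
    static constexpr int DstScalarPerVector = DstScalarPerVector_ / PackedSize;
};

// 16 logical int4 scalars per access map to 8 packed storage elements...
static_assert(GatherTransferLike<pk_i4_t, 16, 16>::SrcScalarPerVector == 8);
// ...while a non-packed type is unchanged.
static_assert(GatherTransferLike<float, 16, 16>::SrcScalarPerVector == 16);

int main() {}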
@@ -71,6 +95,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
           dst_element_op_(dst_element_op),
           gather_offsets_(gather_offsets)
     {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+        {
+            static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
+                          "SrcData != DstData");
+
+            static_assert(
+                SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
+                "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
+
+            static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
+        }
     }
 
     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -107,10 +142,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
 
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
-        static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
+
+        static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
                       "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
 
         constexpr auto src_dim_access_order = SrcDimAccessOrder{};
@@ -212,17 +248,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
                 if constexpr(decltype(src_element_op_)::is_pack8_invocable)
                     return math::min(8, SrcScalarPerVector);
             }
-            if constexpr(is_detected<is_pack4_invocable_t, decltype(src_element_op_)>::value)
+            else if constexpr(is_detected<is_pack4_invocable_t,
+                                          decltype(src_element_op_)>::value)
             {
                 if constexpr(decltype(src_element_op_)::is_pack4_invocable)
                     return math::min(4, SrcScalarPerVector);
             }
-            if constexpr(is_detected<is_pack2_invocable_t, decltype(src_element_op_)>::value)
+            else if constexpr(is_detected<is_pack2_invocable_t,
+                                          decltype(src_element_op_)>::value)
            {
                 if constexpr(decltype(src_element_op_)::is_pack2_invocable)
                     return math::min(2, SrcScalarPerVector);
             }
-            return 1;
+            else
+            {
+                return 1;
+            }
         };
 
         constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
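The rewritten chain above is a detection-idiom probe: each is_packN_invocable_t alias is well-formed only if the element op declares that member, and turning the bare ifs into else-ifs guarantees exactly one branch is taken, widest pack first. A sketch of the same pattern under assumed trait names (is_detected and the probe alias mirror the diff; the ops are hypothetical):

#include <type_traits>

// Minimal detection idiom, analogous to ck's is_detected.
template <template <typename> class Op, typename T, typename = void>
struct is_detected : std::false_type
{
};

template <template <typename> class Op, typename T>
struct is_detected<Op, T, std::void_t<Op<T>>> : std::true_type
{
};

// Probe alias: well-formed only if T declares is_pack2_invocable.
template <typename T>
using is_pack2_invocable_t = decltype(T::is_pack2_invocable);

struct PlainOp
{
};

struct Pack2Op
{
    static constexpr bool is_pack2_invocable = true;
};

template <typename Op>
constexpr int elem_op_vec_len(int scalar_per_vector)
{
    if constexpr(is_detected<is_pack2_invocable_t, Op>::value)
    {
        if constexpr(Op::is_pack2_invocable)
            return scalar_per_vector < 2 ? scalar_per_vector : 2;
    }
    return 1; // fallback: element op handles one scalar at a time
}

static_assert(elem_op_vec_len<Pack2Op>(8) == 2);
static_assert(elem_op_vec_len<PlainOp>(8) == 1);

int main() {}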
@@ -306,7 +347,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
 
         // OOB Check
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
 
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
 
@@ -377,6 +418,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
                           (is_same<f8_t, remove_cvref_t<DstData>>::value &&
                            SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
         {
+            static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
+                          "in-register transpose is not supported for pk_i4_t");
             // each transpose does
             // DstScalarPerVector # of src vectors in src_thread_scratch_
             // SrcScalarPerVector # of dst vectors in dst_thread_scratch_
@@ -437,7 +480,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
         }
         else
         {
-            static_ford<SliceLengths>{}([&](auto idx) {
+            constexpr auto packed_per_access = generate_sequence(
+                detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
+
+            constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
+
+            static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
                 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
             });
         }
@@ -465,7 +513,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
         // src scalar per access on each dim
         // TODO: don't use this
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
 
         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
 
@@ -559,7 +607,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
 
             // copy data from dst_vector_container to dst_buf
             dst_buf.template Set<dst_vector_t>(
-                dst_coord_.GetOffset(),
+                dst_coord_.GetOffset() / PackedSize,
                 is_dst_valid,
                 dst_vector_container.template AsType<dst_vector_t>()[I0]);
 
@@ -613,7 +661,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
 
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
 
@@ -672,7 +720,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
         // scalar per access on each dim
         // TODO: don't use lambda_scalar_per_access
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
 
         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
 
@@ -757,7 +805,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
     __device__ static constexpr auto GetSrcThreadScratchDescriptor()
     {
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
 
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
 
@@ -806,7 +854,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
    __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
    {
        constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
 
        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
 
@@ -817,7 +865,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
     {
         // 1st stage of transforms
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
 
         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
 