Dev/a8w4 and a8w8splitk (#3447)

* Ck moe bs splitk pr (#3440)

* splitk kick-off. Compilation fail

* splitk hack pass

* fix scale offset calc.

* clang-format for a8w8_moe_blk_gemm1 splitk change

* fix testcase error

---------

Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com>

* Zan/moe a8w4 (#3441)

* update

* update

* update ck moe a8w4

* update

* update

* update

* compile pass

* update

* update

* python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready

* support new a8w4 kernel

* update

* update ck_tile

* reformat

* update

* update

* fix conflict

* fix build

* update ck_tile moe

* fix clang format

* fix the problem

* fix accuracy issue

* fix

---------

Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com>
Co-authored-by: Zzz9990 <zanzhang@amd.com>
Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
yadaish
2025-12-19 09:26:52 +08:00
committed by GitHub
parent ba897f8435
commit c0ee71d735
13 changed files with 2911 additions and 139 deletions

View File

@@ -360,6 +360,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
});
});
__builtin_amdgcn_sched_barrier(0);
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
@@ -550,6 +551,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
});
});
__builtin_amdgcn_sched_barrier(0);
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
@@ -677,6 +679,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
});
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {

View File

@@ -74,6 +74,7 @@ template <typename ALayout,
index_t ActivationOP = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool IsSplitK = false,
bool MulRoutedWeight = false,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
@@ -156,6 +157,7 @@ struct DeviceMoeGemmBlockScale
ActivationOP,
NSwizzle,
IsInputGemm,
IsSplitK,
MulRoutedWeight,
IndexType,
ComputeTypeA,
@@ -201,12 +203,12 @@ struct DeviceMoeGemmBlockScale
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
index_t K_split = arg.KBatch == 1 ? arg.K : arg.KBatch * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto RunKernel = [&](const auto& kernel) {
@@ -249,11 +251,12 @@ struct DeviceMoeGemmBlockScale
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
// if(arg_.KBatch > 1)
// hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
// 0,
// arg_.M * arg_.N * sizeof(CDataType)
// * (IsInputGemm && IsSplitK ? 2 : 1),
// stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
@@ -267,11 +270,12 @@ struct DeviceMoeGemmBlockScale
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
// if(arg.KBatch > 1)
// hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
// 0,
// arg.M * arg.N * sizeof(CDataType) *
// (IsInputGemm && IsSplitK ? 2 : 1),
// stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
@@ -289,8 +293,9 @@ struct DeviceMoeGemmBlockScale
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
constexpr auto MemoryDataOp = (IsInputGemm && !IsSplitK)
? InMemoryDataOperationEnum::Set
: InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
{
@@ -416,8 +421,8 @@ struct DeviceMoeGemmBlockScale
static bool IsSupportedArgument(const Argument& arg)
{
// only impl kbatch 1 now
if(arg.KBatch > 1)
// only impl kbatch 1 for fp32
if(arg.KBatch > 1 && !std::is_same_v<CDataType, float>)
{
return false;
}
@@ -441,6 +446,11 @@ struct DeviceMoeGemmBlockScale
{
return false;
}
if(arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
{
// Not support Kpadding with KBatch > 1
return false;
}
if(get_warp_size() == 64)
{

View File

@@ -60,8 +60,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
p_shared,
karg,
karg.a_element_op,
@@ -101,8 +101,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
p_shared,
p_shared1,
karg,
@@ -167,6 +167,7 @@ template <typename ALayout,
index_t ActivationOperation = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool IsSplitK = false,
bool MulRoutedWeight = true,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
@@ -249,13 +250,15 @@ struct GridwiseMoeGemmBlockScale
return 1;
}();
__host__ static auto CalculateGridSize(index_t M, index_t N)
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t K, index_t KBatch)
{
const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
const index_t gridy = NSwizzle ? 1 : mblock;
return std::make_tuple(gridx, gridy, 1);
const index_t gridz = KBatch == 1 ? 1 : math::integer_divide_ceil(K, KPerBlock * KBatch);
return std::make_tuple(gridx, gridy, gridz);
}
__host__ __device__ static auto CalculateMPadded(index_t M)
@@ -284,27 +287,32 @@ struct GridwiseMoeGemmBlockScale
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
// auto K_t = K_Batch * KPerBlock;
// return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
return K_Batch == 1 ? K / AK1Value : K_Batch * KPerBlock / AK1Value;
}
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
// auto K_t = K_Batch * KPerBlock;
// return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
return K_Batch == 1 ? K / BK1Value : K_Batch * KPerBlock / BK1Value;
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * KPerBlock;
// auto K_t = K_Batch * KPerBlock;
// return (K + K_t - 1) / K_t * KPerBlock;
return K_Batch == 1 ? K : K_Batch * KPerBlock;
}
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
{
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
auto K_t = K_Batch * KReadVec;
return (K + K_t - 1) / K_t * KReadVec;
// auto K_t = K_Batch * KReadVec;
// return (K + K_t - 1) / K_t * KReadVec;
return K_Batch == 1 ? math::integer_divide_ceil(K, KReadVec) * KReadVec
: K_Batch * KPerBlock;
}
__host__ __device__ static auto CalculateMBlock(index_t M)
@@ -409,7 +417,6 @@ struct GridwiseMoeGemmBlockScale
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
@@ -741,35 +748,41 @@ struct GridwiseMoeGemmBlockScale
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead / APackedSize;
a_k_split_offset = k_id * karg.KRead / APackedSize;
ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK);
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
bscale_k_split_offset = math::integer_divide_floor(b_k_split_offset, ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
// KPack * NLane * KLane * K0 * N0
b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
bscale_k_split_offset = k_id * karg.KRead / ScaleBlockK;
}
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
}
else
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
// if(k_id < karg.KBatch - 1)
// {
// karg.K = karg.KRead;
// }
// else
// {
// karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
// }
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t ascale_k_split_offset;
index_t bscale_k_split_offset;
};
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -912,8 +925,8 @@ struct GridwiseMoeGemmBlockScale
}
using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector<
BlkGemmPipelineVer,
remove_cvref_t<decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector <
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
@@ -939,7 +952,7 @@ struct GridwiseMoeGemmBlockScale
MXdlPerWave,
NXdlPerWave,
KPack,
IsInputGemm>())>;
IsInputGemm && !IsSplitK > ())>;
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
@@ -1189,9 +1202,9 @@ struct GridwiseMoeGemmBlockScale
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
problem.MPadded,
@@ -1204,8 +1217,8 @@ struct GridwiseMoeGemmBlockScale
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
problem.MPadded,
problem.N,
problem.NPadded,
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
@@ -1215,7 +1228,8 @@ struct GridwiseMoeGemmBlockScale
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
@@ -1371,9 +1385,10 @@ struct GridwiseMoeGemmBlockScale
decltype(c_thread_buf) c_thread_buf_up;
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
problem.KBatch == 1
? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock
: problem.KBatch);
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
@@ -1447,7 +1462,7 @@ struct GridwiseMoeGemmBlockScale
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
if constexpr(IsInputGemm)
if constexpr(IsInputGemm && !IsSplitK)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1606,7 +1621,7 @@ struct GridwiseMoeGemmBlockScale
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, n2 * N4 + n4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion, elementwise
if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
@@ -1743,8 +1758,12 @@ struct GridwiseMoeGemmBlockScale
using EDataType = CDataType;
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto ds_grid_desc_m_n =
MakeDsGridDescriptor_M_N(problem.M,
problem.MPadded,
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
problem.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -1875,7 +1894,8 @@ struct GridwiseMoeGemmBlockScale
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_offsets(m0) =
token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
});
block_sync_lds();
@@ -1953,8 +1973,8 @@ struct GridwiseMoeGemmBlockScale
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
problem.MPadded,
problem.N,
problem.NPadded,
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
@@ -2125,8 +2145,10 @@ struct GridwiseMoeGemmBlockScale
decltype(c_thread_buf) c_thread_buf_up;
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
problem.KBatch == 1
? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock
: problem.KBatch);
// scale
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
@@ -2202,7 +2224,7 @@ struct GridwiseMoeGemmBlockScale
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
if constexpr(IsInputGemm)
if constexpr(IsInputGemm && !IsSplitK)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -2352,7 +2374,7 @@ struct GridwiseMoeGemmBlockScale
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, n2 * N4 + n4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion, elementwise
if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
@@ -2619,7 +2641,8 @@ struct GridwiseMoeGemmBlockScale
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_offsets(m0) =
token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
});
block_sync_lds();