mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Dev/a8w4 and a8w8splitk (#3447)
* Ck moe bs splitk pr (#3440) * splitk kick-off. Compilation fail * splitk hack pass * fix scale offset calc. * clang-format for a8w8_moe_blk_gemm1 splitk change * fix testcase error --------- Co-authored-by: oscar <huaiguxu@amd.com> Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com> * Zan/moe a8w4 (#3441) * update * update * update ck moe a8w4 * update * update * update * compile pass * update * update * python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready * support new a8w4 kernel * update * update ck_tile * re format * update * update * fix conflict * fix build * update ck_tile moe * fix clang format * fix the problem * fix accuracy issue * fix --------- Co-authored-by: oscar <huaiguxu@amd.com> Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com> Co-authored-by: Zzz9990 <zanzhang@amd.com> Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
@@ -60,8 +60,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
@@ -101,8 +101,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
karg,
|
||||
@@ -167,6 +167,7 @@ template <typename ALayout,
|
||||
index_t ActivationOperation = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool IsSplitK = false,
|
||||
bool MulRoutedWeight = true,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
@@ -249,13 +250,15 @@ struct GridwiseMoeGemmBlockScale
|
||||
return 1;
|
||||
}();
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t K, index_t KBatch)
|
||||
{
|
||||
const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
|
||||
const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
|
||||
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
|
||||
const index_t gridy = NSwizzle ? 1 : mblock;
|
||||
return std::make_tuple(gridx, gridy, 1);
|
||||
const index_t gridz = KBatch == 1 ? 1 : math::integer_divide_ceil(K, KPerBlock * KBatch);
|
||||
|
||||
return std::make_tuple(gridx, gridy, gridz);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateMPadded(index_t M)
|
||||
@@ -284,27 +287,32 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
|
||||
// auto K_t = K_Batch * KPerBlock;
|
||||
// return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
|
||||
return K_Batch == 1 ? K / AK1Value : K_Batch * KPerBlock / AK1Value;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
|
||||
// auto K_t = K_Batch * KPerBlock;
|
||||
// return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
|
||||
return K_Batch == 1 ? K / BK1Value : K_Batch * KPerBlock / BK1Value;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * KPerBlock;
|
||||
// auto K_t = K_Batch * KPerBlock;
|
||||
// return (K + K_t - 1) / K_t * KPerBlock;
|
||||
return K_Batch == 1 ? K : K_Batch * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
|
||||
auto K_t = K_Batch * KReadVec;
|
||||
return (K + K_t - 1) / K_t * KReadVec;
|
||||
// auto K_t = K_Batch * KReadVec;
|
||||
// return (K + K_t - 1) / K_t * KReadVec;
|
||||
return K_Batch == 1 ? math::integer_divide_ceil(K, KReadVec) * KReadVec
|
||||
: K_Batch * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateMBlock(index_t M)
|
||||
@@ -409,7 +417,6 @@ struct GridwiseMoeGemmBlockScale
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
}
|
||||
@@ -741,35 +748,41 @@ struct GridwiseMoeGemmBlockScale
|
||||
{
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
bscale_k_split_offset = math::integer_divide_floor(b_k_split_offset, ScaleBlockK);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
// KPack * NLane * KLane * K0 * N0
|
||||
b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
|
||||
b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
|
||||
bscale_k_split_offset = k_id * karg.KRead / ScaleBlockK;
|
||||
}
|
||||
|
||||
if(k_id < karg.KBatch - 1)
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
}
|
||||
else
|
||||
{
|
||||
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
|
||||
}
|
||||
// if(k_id < karg.KBatch - 1)
|
||||
// {
|
||||
// karg.K = karg.KRead;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
|
||||
// }
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t ascale_k_split_offset;
|
||||
index_t bscale_k_split_offset;
|
||||
};
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -912,8 +925,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
}
|
||||
|
||||
using BlockwiseGemmPipe =
|
||||
remove_cvref_t<decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector<
|
||||
BlkGemmPipelineVer,
|
||||
remove_cvref_t<decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector <
|
||||
BlkGemmPipelineVer,
|
||||
BlkGemmPipeSched,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -939,7 +952,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
IsInputGemm>())>;
|
||||
IsInputGemm && !IsSplitK > ())>;
|
||||
|
||||
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
@@ -1189,9 +1202,9 @@ struct GridwiseMoeGemmBlockScale
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
problem.MPadded,
|
||||
@@ -1204,8 +1217,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
|
||||
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
|
||||
problem.MPadded,
|
||||
problem.N,
|
||||
problem.NPadded,
|
||||
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
@@ -1215,7 +1228,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
|
||||
make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
ScaleBlockN),
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
|
||||
@@ -1371,9 +1385,10 @@ struct GridwiseMoeGemmBlockScale
|
||||
decltype(c_thread_buf) c_thread_buf_up;
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
problem.KBatch == 1
|
||||
? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock
|
||||
: problem.KBatch);
|
||||
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
|
||||
constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
|
||||
constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
|
||||
@@ -1447,7 +1462,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
|
||||
|
||||
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
|
||||
if constexpr(IsInputGemm)
|
||||
if constexpr(IsInputGemm && !IsSplitK)
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -1606,7 +1621,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, n2 * N4 + n4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion, elementwise
|
||||
if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
@@ -1743,8 +1758,12 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
using EDataType = CDataType;
|
||||
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
const auto ds_grid_desc_m_n =
|
||||
MakeDsGridDescriptor_M_N(problem.M,
|
||||
problem.MPadded,
|
||||
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.StrideDs);
|
||||
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -1875,7 +1894,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_offsets(m0) =
|
||||
token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
@@ -1953,8 +1973,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
|
||||
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
|
||||
problem.MPadded,
|
||||
problem.N,
|
||||
problem.NPadded,
|
||||
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
@@ -2125,8 +2145,10 @@ struct GridwiseMoeGemmBlockScale
|
||||
decltype(c_thread_buf) c_thread_buf_up;
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
problem.KBatch == 1
|
||||
? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock
|
||||
: problem.KBatch);
|
||||
|
||||
// scale
|
||||
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
|
||||
@@ -2202,7 +2224,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
|
||||
|
||||
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
|
||||
if constexpr(IsInputGemm)
|
||||
if constexpr(IsInputGemm && !IsSplitK)
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -2352,7 +2374,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, n2 * N4 + n4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion, elementwise
|
||||
if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
@@ -2619,7 +2641,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_offsets(m0) =
|
||||
token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
Reference in New Issue
Block a user