splitk hack pass

This commit is contained in:
oscar
2025-11-28 17:41:29 +08:00
parent cb1bea4929
commit c79879c669
3 changed files with 64 additions and 44 deletions

View File

@@ -185,16 +185,21 @@ int main(int argc, char* argv[])
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 6144;
// ck::index_t N = 128;
// ck::index_t K = 512;
ck::index_t experts = 8;
ck::index_t topk = 2;
// ck::index_t sorted_tile_num = 515;
// ck::index_t valid_tile_num = 512;
// ck::index_t tokens = 8192;
// ck::index_t tokens = 208;
// ck::index_t sorted_tile_num = 15;
// ck::index_t valid_tile_num = 13;
ck::index_t sorted_tile_num = 259;
ck::index_t valid_tile_num = 256;
ck::index_t tokens = 4096;
// ck::index_t sorted_tile_num = 259;
// ck::index_t valid_tile_num = 256;
// ck::index_t tokens = 4096;
ck::index_t sorted_tile_num = 2;
ck::index_t valid_tile_num = 2;
ck::index_t tokens = 32;
#else
// deepseek
ck::index_t N = 2048;
@@ -256,14 +261,14 @@ int main(int argc, char* argv[])
}
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t StrideE = N * 2;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0};
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2;
ck::index_t KBatch = 1;
ck::index_t KBatch = 6;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
@@ -319,9 +324,9 @@ int main(int argc, char* argv[])
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-1.0, 1.0});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
break;
case 2:

View File

@@ -203,12 +203,11 @@ struct DeviceMoeGemmBlockScale
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
index_t K_split = arg.KBatch == 1 ? arg.K : arg.KBatch * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto RunKernel = [&](const auto& kernel) {
@@ -443,6 +442,11 @@ struct DeviceMoeGemmBlockScale
{
return false;
}
if (arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
{
// Not support Kpadding with KBatch > 1
return false;
}
if(get_warp_size() == 64)
{

View File

@@ -51,6 +51,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
// printf("splitk_batch_offset.a_k_split_offset: %d\n", splitk_batch_offset.a_k_split_offset);
// printf("splitk_batch_offset.b_k_split_offset: %d\n", splitk_batch_offset.b_k_split_offset);
// printf("splitk_batch_offset.ascale_k_split_offset: %d\n", splitk_batch_offset.ascale_k_split_offset);
// printf("splitk_batch_offset.bscale_k_split_offset: %d\n", splitk_batch_offset.bscale_k_split_offset);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
@@ -60,8 +64,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
p_shared,
karg,
karg.a_element_op,
@@ -101,8 +105,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
p_shared,
p_shared1,
karg,
@@ -250,13 +254,15 @@ struct GridwiseMoeGemmBlockScale
return 1;
}();
// Host-side helper that returns the launch grid as a tuple (gridx, gridy, gridz).
// NOTE(review): this span contains two signature lines and two `return` statements. It looks
// like the pre- and post-change lines of a diff shown together with the +/- markers stripped;
// confirm which version is meant to be live before relying on the comments below.
__host__ static auto CalculateGridSize(index_t M, index_t N)
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t K, index_t KBatch)
{
// Number of NPerBlock-wide and MPerBlock-tall tiles, rounded up.
const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
// With NSwizzle the M and N tile counts are folded into gridx and gridy collapses to 1;
// otherwise gridx counts N tiles and gridy counts M tiles.
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
const index_t gridy = NSwizzle ? 1 : mblock;
// Two-argument form: the z dimension is fixed at 1.
return std::make_tuple(gridx, gridy, 1);
// Four-argument form: for KBatch != 1, gridz is the number of K slices of length
// KPerBlock * KBatch (rounded up). KBatch therefore scales the slice length here, it is not
// itself the slice count.
const index_t gridz = KBatch == 1 ? 1 : math::integer_divide_ceil(K, KPerBlock * KBatch);
return std::make_tuple(gridx, gridy, gridz);
}
__host__ __device__ static auto CalculateMPadded(index_t M)
@@ -285,27 +291,31 @@ struct GridwiseMoeGemmBlockScale
// Returns a K extent expressed in units of AK1Value for problem depth K and split
// parameter K_Batch (defaults to 1).
// NOTE(review): two live `return` statements appear below. As written, the first one returns
// unconditionally and the last one is unreachable. This looks like pre- and post-change diff
// lines shown together with the +/- markers stripped; confirm which one is intended.
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
{
// First form: round K up to a multiple of K_Batch * KPerBlock, then take the per-batch
// share in AK1Value units.
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
// auto K_t = K_Batch * KPerBlock;
// return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
// Second form: K_Batch == 1 yields K / AK1Value with no round-up of K; any other K_Batch
// yields K_Batch * KPerBlock / AK1Value, which does not depend on K.
return K_Batch == 1 ? K / AK1Value : K_Batch * KPerBlock / AK1Value;
}
// Returns a K extent expressed in units of BK1Value for problem depth K and split
// parameter K_Batch (defaults to 1).
// NOTE(review): two live `return` statements appear below. As written, the first one returns
// unconditionally and the last one is unreachable. This looks like pre- and post-change diff
// lines shown together with the +/- markers stripped; confirm which one is intended.
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
{
// First form: round K up to a multiple of K_Batch * KPerBlock, then take the per-batch
// share in BK1Value units.
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
// auto K_t = K_Batch * KPerBlock;
// return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
// Second form: K_Batch == 1 yields K / BK1Value with no round-up of K; any other K_Batch
// yields K_Batch * KPerBlock / BK1Value, which does not depend on K.
return K_Batch == 1 ? K / BK1Value : K_Batch * KPerBlock / BK1Value;
}
// Returns a K extent (in elements) for problem depth K and split parameter K_Batch
// (defaults to 1).
// NOTE(review): two live `return` statements appear below. As written, the first one returns
// unconditionally and the last one is unreachable. This looks like pre- and post-change diff
// lines shown together with the +/- markers stripped; confirm which one is intended.
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
// First form: ceil(K / (K_Batch * KPerBlock)) * KPerBlock.
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * KPerBlock;
// auto K_t = K_Batch * KPerBlock;
// return (K + K_t - 1) / K_t * KPerBlock;
// Second form: K_Batch == 1 returns K unchanged (no padding); any other K_Batch returns
// K_Batch * KPerBlock, which does not depend on K.
return K_Batch == 1 ? K : K_Batch * KPerBlock;
}
// Returns a K length (in elements) for problem depth K and split parameter K_Batch
// (defaults to 1).
// NOTE(review): two live `return` statements appear below. As written, the first one returns
// unconditionally and the last one is unreachable. This looks like pre- and post-change diff
// lines shown together with the +/- markers stripped; confirm which one is intended.
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
{
// Rounding granularity: least common multiple of AK1Number and BK1Number.
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
// First form: ceil(K / (K_Batch * KReadVec)) * KReadVec.
auto K_t = K_Batch * KReadVec;
return (K + K_t - 1) / K_t * KReadVec;
// auto K_t = K_Batch * KReadVec;
// return (K + K_t - 1) / K_t * KReadVec;
// Second form: K_Batch == 1 rounds K up to a multiple of KReadVec; any other K_Batch
// returns K_Batch * KPerBlock, which does not depend on K.
return K_Batch == 1 ? math::integer_divide_ceil(K, KReadVec) * KReadVec : K_Batch * KPerBlock;
}
__host__ __device__ static auto CalculateMBlock(index_t M)
@@ -410,7 +420,6 @@ struct GridwiseMoeGemmBlockScale
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
@@ -758,19 +767,22 @@ struct GridwiseMoeGemmBlockScale
// KPack * NLane * KLane * K0 * N0
b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
}
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
}
else
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
ascale_k_split_offset = math::integer_divide_ceil(a_k_split_offset, ScaleBlockK);
bscale_k_split_offset = math::integer_divide_ceil(b_k_split_offset, ScaleBlockK);
// if(k_id < karg.KBatch - 1)
// {
// karg.K = karg.KRead;
// }
// else
// {
// karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
// }
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t ascale_k_split_offset;
index_t bscale_k_split_offset;
};
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -1191,7 +1203,7 @@ struct GridwiseMoeGemmBlockScale
CElementwiseOperation c_element_op)
{
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
@@ -1206,7 +1218,7 @@ struct GridwiseMoeGemmBlockScale
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
problem.MPadded,
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
problem.NPadded,
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
@@ -1372,9 +1384,8 @@ struct GridwiseMoeGemmBlockScale
decltype(c_thread_buf) c_thread_buf_up;
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
problem.KBatch == 1 ?(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock : problem.KBatch);
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
@@ -1940,7 +1951,7 @@ struct GridwiseMoeGemmBlockScale
CElementwiseOperation c_element_op)
{
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
@@ -1955,7 +1966,7 @@ struct GridwiseMoeGemmBlockScale
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
problem.MPadded,
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
problem.NPadded,
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
@@ -2126,8 +2137,8 @@ struct GridwiseMoeGemmBlockScale
decltype(c_thread_buf) c_thread_buf_up;
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
problem.KBatch == 1 ?(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock : problem.KBatch);
// scale
constexpr index_t ScaleSliceSizeM = MXdlPerWave;