mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
splitk hack pass
This commit is contained in:
@@ -185,16 +185,21 @@ int main(int argc, char* argv[])
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 6144;
|
||||
// ck::index_t N = 128;
|
||||
// ck::index_t K = 512;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t topk = 2;
|
||||
// ck::index_t sorted_tile_num = 515;
|
||||
// ck::index_t valid_tile_num = 512;
|
||||
// ck::index_t tokens = 8192;
|
||||
// ck::index_t tokens = 208;
|
||||
// ck::index_t sorted_tile_num = 15;
|
||||
// ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_tile_num = 259;
|
||||
ck::index_t valid_tile_num = 256;
|
||||
ck::index_t tokens = 4096;
|
||||
// ck::index_t sorted_tile_num = 259;
|
||||
// ck::index_t valid_tile_num = 256;
|
||||
// ck::index_t tokens = 4096;
|
||||
ck::index_t sorted_tile_num = 2;
|
||||
ck::index_t valid_tile_num = 2;
|
||||
ck::index_t tokens = 32;
|
||||
#else
|
||||
// deepseek
|
||||
ck::index_t N = 2048;
|
||||
@@ -256,14 +261,14 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
ck::index_t StrideE = N * 2;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0};
|
||||
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
ck::index_t KBatch = 6;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
@@ -319,9 +324,9 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-1.0, 1.0});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
|
||||
@@ -203,12 +203,11 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
index_t K_split = arg.KBatch == 1 ? arg.K : arg.KBatch * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
@@ -443,6 +442,11 @@ struct DeviceMoeGemmBlockScale
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if (arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
|
||||
{
|
||||
// Not support Kpadding with KBatch > 1
|
||||
return false;
|
||||
}
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
|
||||
@@ -51,6 +51,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
// printf("splitk_batch_offset.a_k_split_offset: %d\n", splitk_batch_offset.a_k_split_offset);
|
||||
// printf("splitk_batch_offset.b_k_split_offset: %d\n", splitk_batch_offset.b_k_split_offset);
|
||||
// printf("splitk_batch_offset.ascale_k_split_offset: %d\n", splitk_batch_offset.ascale_k_split_offset);
|
||||
// printf("splitk_batch_offset.bscale_k_split_offset: %d\n", splitk_batch_offset.bscale_k_split_offset);
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
@@ -60,8 +64,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
@@ -101,8 +105,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
karg,
|
||||
@@ -250,13 +254,15 @@ struct GridwiseMoeGemmBlockScale
|
||||
return 1;
|
||||
}();
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t K, index_t KBatch)
|
||||
{
|
||||
const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
|
||||
const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
|
||||
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
|
||||
const index_t gridy = NSwizzle ? 1 : mblock;
|
||||
return std::make_tuple(gridx, gridy, 1);
|
||||
const index_t gridz = KBatch == 1 ? 1 : math::integer_divide_ceil(K, KPerBlock * KBatch);
|
||||
|
||||
return std::make_tuple(gridx, gridy, gridz);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateMPadded(index_t M)
|
||||
@@ -285,27 +291,31 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
|
||||
// auto K_t = K_Batch * KPerBlock;
|
||||
// return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
|
||||
return K_Batch == 1 ? K / AK1Value : K_Batch * KPerBlock / AK1Value;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
|
||||
// auto K_t = K_Batch * KPerBlock;
|
||||
// return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
|
||||
return K_Batch == 1 ? K / BK1Value : K_Batch * KPerBlock / BK1Value;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * KPerBlock;
|
||||
// auto K_t = K_Batch * KPerBlock;
|
||||
// return (K + K_t - 1) / K_t * KPerBlock;
|
||||
return K_Batch == 1 ? K : K_Batch * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
|
||||
auto K_t = K_Batch * KReadVec;
|
||||
return (K + K_t - 1) / K_t * KReadVec;
|
||||
// auto K_t = K_Batch * KReadVec;
|
||||
// return (K + K_t - 1) / K_t * KReadVec;
|
||||
return K_Batch == 1 ? math::integer_divide_ceil(K, KReadVec) * KReadVec : K_Batch * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateMBlock(index_t M)
|
||||
@@ -410,7 +420,6 @@ struct GridwiseMoeGemmBlockScale
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
}
|
||||
@@ -758,19 +767,22 @@ struct GridwiseMoeGemmBlockScale
|
||||
// KPack * NLane * KLane * K0 * N0
|
||||
b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
|
||||
}
|
||||
|
||||
if(k_id < karg.KBatch - 1)
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
}
|
||||
else
|
||||
{
|
||||
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
|
||||
}
|
||||
ascale_k_split_offset = math::integer_divide_ceil(a_k_split_offset, ScaleBlockK);
|
||||
bscale_k_split_offset = math::integer_divide_ceil(b_k_split_offset, ScaleBlockK);
|
||||
// if(k_id < karg.KBatch - 1)
|
||||
// {
|
||||
// karg.K = karg.KRead;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
|
||||
// }
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t ascale_k_split_offset;
|
||||
index_t bscale_k_split_offset;
|
||||
};
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -1191,7 +1203,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
@@ -1206,7 +1218,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
|
||||
problem.MPadded,
|
||||
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.NPadded,
|
||||
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
@@ -1372,9 +1384,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
decltype(c_thread_buf) c_thread_buf_up;
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
problem.KBatch == 1 ?(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock : problem.KBatch);
|
||||
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
|
||||
constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
|
||||
constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
|
||||
@@ -1940,7 +1951,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
@@ -1955,7 +1966,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
|
||||
problem.MPadded,
|
||||
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.NPadded,
|
||||
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
@@ -2126,8 +2137,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
decltype(c_thread_buf) c_thread_buf_up;
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
problem.KBatch == 1 ?(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock : problem.KBatch);
|
||||
|
||||
// scale
|
||||
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
|
||||
|
||||
Reference in New Issue
Block a user