mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
impl gemm2 swizzle
This commit is contained in:
@@ -236,7 +236,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
|
||||
// const ck::index_t experts = 8;
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({experts}, {1}));
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
|
||||
@@ -197,8 +197,11 @@ struct GridwiseMoeGemmScatter
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N)
|
||||
{
|
||||
return std::make_tuple(math::integer_divide_ceil(N, NPerBlock),
|
||||
math::integer_divide_ceil(M, MPerBlock),
|
||||
// return std::make_tuple(math::integer_divide_ceil(N, NPerBlock),
|
||||
// math::integer_divide_ceil(M, MPerBlock),
|
||||
// 1);
|
||||
return std::make_tuple(math::integer_divide_ceil(N, NPerBlock) * math::integer_divide_ceil(M, MPerBlock),
|
||||
1,
|
||||
1);
|
||||
}
|
||||
|
||||
@@ -1149,10 +1152,22 @@ struct GridwiseMoeGemmScatter
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
|
||||
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
|
||||
|
||||
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
|
||||
const index_t expert_block_id = blockIdx.x / problem.NBlock;
|
||||
// const index_t b_block_id = blockIdx.x % problem.NBlock;
|
||||
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
|
||||
const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]);
|
||||
const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1];
|
||||
const index_t expert_block_swizzle = expert_block_id / expert_swizzle;
|
||||
const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle);
|
||||
const index_t block_n_id = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8);
|
||||
const index_t block_m_id = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle);
|
||||
|
||||
// const index_t block_n_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
|
||||
// const index_t block_m_id = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
// const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
|
||||
// const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
|
||||
const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
|
||||
|
||||
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
|
||||
Reference in New Issue
Block a user