set 16x16

This commit is contained in:
coderfeli
2025-04-25 03:09:53 +00:00
parent 2054e165bc
commit f9c29b5ec7
2 changed files with 23 additions and 26 deletions

View File

@@ -1681,7 +1681,8 @@ struct GridwiseMoeGemm
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
// static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
if(expert_block_id * MPerBlock >= max_token_id)
return;
@@ -1690,12 +1691,13 @@ struct GridwiseMoeGemm
const auto block_mn = [&]() -> std::pair<int, int> {
if constexpr(NSwizzle)
{
const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
const index_t prefix_block = ecnt_prefix * problem.NBlock;
const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
const index_t bid_new = blockIdx.x - prefix_block;
const index_t nid = __builtin_amdgcn_readfirstlane(
const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
const index_t prefix_block = ecnt_prefix * problem.NBlock;
const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
const index_t expert_swizzle =
ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
const index_t bid_new = blockIdx.x - prefix_block;
const index_t nid = __builtin_amdgcn_readfirstlane(
bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
const index_t mid =
__builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
@@ -1708,7 +1710,6 @@ struct GridwiseMoeGemm
}();
const index_t block_n_id = block_mn.first;
const index_t block_m_id = block_mn.second;
const index_t token0 =
__builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
@@ -1720,11 +1721,9 @@ struct GridwiseMoeGemm
constexpr auto AMRepeats = MPerBlock / AMThreads;
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
token0 >= problem.NumTokens)
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats>
gather_offsets; //= p_sorted_token_ids[token_pos];
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets;
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
@@ -2083,8 +2082,7 @@ struct GridwiseMoeGemm
const float* p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets;
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
auto dstidx = sfc_cde_block.GetIndex(access_id);