mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
fix nswizzle = true
This commit is contained in:
@@ -139,7 +139,7 @@ static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 32;
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t Nswizzle = true;
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
@@ -199,7 +199,7 @@ int main(int argc, char* argv[])
|
||||
ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
ck::index_t tokens = 64;
|
||||
ck::index_t tokens = 544;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
// ck::index_t tokens = batch * topk;
|
||||
@@ -245,17 +245,18 @@ int main(int argc, char* argv[])
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 1, 2,2,0,0,0};
|
||||
int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 7, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
// max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0};
|
||||
max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13};
|
||||
int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
for (int i = 0; i < sorted_tile_num; i++) {
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
int token_per_tile = tokens * topk / valid_tile_num;
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
// sorted_token_ids.mData[0] = 0;
|
||||
for (int i = 0; i < sorted_size; i++) {
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile)
|
||||
if(tile_off < token_per_tile && tokenid < tokens * topk)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
@@ -274,7 +275,6 @@ int main(int argc, char* argv[])
|
||||
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
|
||||
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
|
||||
@@ -287,8 +287,8 @@ int main(int argc, char* argv[])
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
|
||||
d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 2});
|
||||
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 2});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
|
||||
@@ -1174,16 +1174,30 @@ struct GridwiseMoeGemm
|
||||
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
|
||||
// constexpr int expert_tile_cnt[8] = {2, 1, 1, 2, 2, 2, 1, 2};
|
||||
// const index_t b_block_id = blockIdx.x % problem.NBlock;
|
||||
const index_t expert_block_id = blockIdx.x / problem.NBlock;
|
||||
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[blockIdx.x / problem.NBlock]);
|
||||
const auto block_mn = [&]() -> std::pair<int, int> {
|
||||
if constexpr (NSwizzle)
|
||||
{
|
||||
const index_t expert_block_id = blockIdx.x / problem.NBlock;
|
||||
const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]);
|
||||
const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1];
|
||||
const index_t expert_block_swizzle = expert_block_id / expert_swizzle;
|
||||
const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle);
|
||||
const index_t nid = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8);
|
||||
const index_t mid = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle);
|
||||
// const index_t expert_block_id = blockIdx.x / problem.NBlock; //
|
||||
// const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]);
|
||||
// const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1];
|
||||
// const index_t expert_block_swizzle = expert_block_id / expert_swizzle;
|
||||
// const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle);
|
||||
// const index_t nid = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8);
|
||||
// const index_t mid = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle);
|
||||
// if(threadIdx.x==0)
|
||||
// printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, es, p_sorted_expert_ids[expert_block_id]);
|
||||
|
||||
const index_t ecnt_prefix = p_max_token_id[1+expert_id];
|
||||
const index_t prefix_block = ecnt_prefix * problem.NBlock;
|
||||
const index_t ecnt = p_max_token_id[2+expert_id] - ecnt_prefix;
|
||||
const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; //p_max_token_id[expert_id + 1]; // 2
|
||||
const index_t bid_new = blockIdx.x - prefix_block;
|
||||
const index_t nid = __builtin_amdgcn_readfirstlane(bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
|
||||
const index_t mid = __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
|
||||
// if(threadIdx.x==0)
|
||||
// printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, ecnt, expert_id);
|
||||
return {nid, mid};
|
||||
} else {
|
||||
return {blockIdx.x, blockIdx.y};
|
||||
@@ -1191,7 +1205,6 @@ struct GridwiseMoeGemm
|
||||
}();
|
||||
const index_t block_n_id = block_mn.first;
|
||||
const index_t block_m_id = block_mn.second;
|
||||
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
|
||||
// if (threadIdx.x==0) {
|
||||
// printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id);
|
||||
// }
|
||||
@@ -1205,7 +1218,7 @@ struct GridwiseMoeGemm
|
||||
constexpr auto AMRepeats = MPerBlock / AMThreads;
|
||||
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
|
||||
|
||||
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
|
||||
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens)
|
||||
return;
|
||||
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
|
||||
static_for<0, AMRepeats, 1>{}([&](auto m0) {
|
||||
|
||||
Reference in New Issue
Block a user