mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
tidy code
This commit is contained in:
@@ -131,8 +131,6 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
|
||||
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
|
||||
// std::cout << "topk_id:" << topk_ids_host << std::endl;
|
||||
// std::cout << "local_expert_masking:" << local_expert_masking_host << std::endl;
|
||||
|
||||
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
|
||||
@@ -177,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
warmup,
|
||||
repeat};
|
||||
auto ms = moe_sorting(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
|
||||
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
num_experts,
|
||||
topk,
|
||||
ms);
|
||||
topk);
|
||||
|
||||
if(local_expert_masking)
|
||||
{
|
||||
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
|
||||
}
|
||||
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
else
|
||||
printf("ms:%f, ", ms);
|
||||
fflush(stdout);
|
||||
if(ms < 0)
|
||||
{
|
||||
@@ -221,19 +226,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
local_expert_masking);
|
||||
rtn &= ck_tile::check_err(
|
||||
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
|
||||
// std::cout << "sorted_ids_ref:"<<sorted_ids_ref<<std::endl;
|
||||
rtn &= ck_tile::check_err(sorted_weights_host,
|
||||
sorted_weights_ref,
|
||||
std::string("OUT Error: Incorrect w!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
// std::cout << "sorted_weights_ref:"<<sorted_weights_ref<<std::endl;
|
||||
rtn &= ck_tile::check_err(sorted_expert_ids_host,
|
||||
sorted_expert_ids_ref,
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
// std::cout << "sorted_expert_ids_ref:"<<sorted_expert_ids_ref<<std::endl;
|
||||
if(moe_buf_size)
|
||||
{
|
||||
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
|
||||
|
||||
@@ -22,3 +22,7 @@ $EXE -t=64 -e=455 -k=8
|
||||
$EXE -t=777 -e=802 -k=99
|
||||
$EXE -t=4097 -e=906 -k=51
|
||||
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
|
||||
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
|
||||
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
|
||||
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
|
||||
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
|
||||
|
||||
@@ -27,12 +27,12 @@
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/null_type.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/tensor/buffer_view.hpp"
|
||||
|
||||
@@ -104,7 +104,6 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
|
||||
|
||||
for(index_t s = 0; s < expert_slices[e]; s++)
|
||||
{
|
||||
// out_expert_id[s] = e;
|
||||
out_expert_id[s] = curr_expert_id;
|
||||
unit_cnt++;
|
||||
}
|
||||
|
||||
@@ -351,7 +351,6 @@ struct MoeSortingKernel
|
||||
bound_ctrl)); // row_newbcast:7
|
||||
|
||||
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
|
||||
// printf("[%d]eid:%d, thread_data:%d, xxx:%d, yyy:%d (%d)\n", threadIdx.x, threadIdx.x/8, thread_data, xxx, yyy, (__lane_id() / 8) % 2);
|
||||
thread_data = thread_data - yyy;
|
||||
#endif
|
||||
|
||||
@@ -683,12 +682,9 @@ struct MoeSortingKernel
|
||||
index_t* p_total_tokens_post_pad,
|
||||
const index_t num_experts,
|
||||
const index_t tokens,
|
||||
// const index_t tokens_per_thread,
|
||||
// const index_t numel,
|
||||
const mdiv unit_size_mdiv,
|
||||
const mdiv topk_mdiv,
|
||||
const mdiv expert_mdiv,
|
||||
// const mdiv sub_tokens_mdiv,
|
||||
const index_t smem_rows,
|
||||
void* smem) const
|
||||
{
|
||||
@@ -701,18 +697,12 @@ struct MoeSortingKernel
|
||||
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
|
||||
|
||||
const index_t smem_cols = num_experts + 1;
|
||||
// const index_t total_smem_tokens_pixel = sub_tokens * num_experts; // no need consider
|
||||
// padding
|
||||
#if 0
|
||||
simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem), smem_cols};
|
||||
simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + sub_tokens * smem_cols};
|
||||
simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + sub_tokens * smem_cols + smem_cols};
|
||||
#else
|
||||
|
||||
simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + 0};
|
||||
simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + smem_cols};
|
||||
simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem) + 2 * smem_cols,
|
||||
smem_cols};
|
||||
#endif
|
||||
|
||||
// #pragma unroll 8
|
||||
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
|
||||
{
|
||||
@@ -733,13 +723,8 @@ struct MoeSortingKernel
|
||||
|
||||
if(i_t < tokens)
|
||||
{
|
||||
|
||||
int eid = topk_id[i_t * topk + curr_topk_id];
|
||||
|
||||
// printf("eid:%d, [%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d,
|
||||
// curr_topk_id:%d, tokens:%d\n",
|
||||
// eid, i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens);
|
||||
// smem_tokens(curr_token_id, eid)++;
|
||||
if constexpr(Problem::SubTokenOneShot)
|
||||
smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
|
||||
else
|
||||
@@ -748,27 +733,6 @@ struct MoeSortingKernel
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
}
|
||||
__syncthreads(); // make sure different i_token iteration not overlap by different wave
|
||||
// if(tid == 0) {
|
||||
// int e0 = smem_tokens(0, 0);
|
||||
// int e1 = smem_tokens(1, 0);
|
||||
// int e2 = smem_tokens(2, 0);
|
||||
// int e3 = smem_tokens(3, 0);
|
||||
// int e4 = smem_tokens(4, 0);
|
||||
// int e5 = smem_tokens(5, 0);
|
||||
// int e6 = smem_tokens(6, 0);
|
||||
// int e7 = smem_tokens(7, 0);
|
||||
// printf("xxx eid:%d i_token:%d, cnt:%d,%d,%d,%d,%d,%d,%d,%d(%d)\n", 0, i_token,
|
||||
// e0,
|
||||
// e1,
|
||||
// e2,
|
||||
// e3,
|
||||
// e4,
|
||||
// e5,
|
||||
// e6,
|
||||
// e7,
|
||||
// e0+e1+e2+e3+e4+e5+e6+e7
|
||||
// );
|
||||
// }
|
||||
}
|
||||
|
||||
// counting
|
||||
@@ -777,36 +741,7 @@ struct MoeSortingKernel
|
||||
smem_cumsum(0) = 0;
|
||||
// smem_cumdup(0) = 0;
|
||||
}
|
||||
#if 0
|
||||
(void)f_sum;
|
||||
for(int i_e = tid; i_e < num_experts; i_e += block_size)
|
||||
{
|
||||
index_t local_c[8];
|
||||
index_t cnt = 0;
|
||||
// TODO: manually unroll. pragma unroll does not work well when we have dependency
|
||||
for(int i = 0; i < sub_tokens; i += 8)
|
||||
{
|
||||
local_c[0] = smem_tokens(i + 0, i_e);
|
||||
local_c[1] = smem_tokens(i + 1, i_e);
|
||||
local_c[2] = smem_tokens(i + 2, i_e);
|
||||
local_c[3] = smem_tokens(i + 3, i_e);
|
||||
local_c[4] = smem_tokens(i + 4, i_e);
|
||||
local_c[5] = smem_tokens(i + 5, i_e);
|
||||
local_c[6] = smem_tokens(i + 6, i_e);
|
||||
local_c[7] = smem_tokens(i + 7, i_e);
|
||||
|
||||
cnt += local_c[0];
|
||||
cnt += local_c[1];
|
||||
cnt += local_c[2];
|
||||
cnt += local_c[3];
|
||||
cnt += local_c[4];
|
||||
cnt += local_c[5];
|
||||
cnt += local_c[6];
|
||||
cnt += local_c[7];
|
||||
}
|
||||
smem_cumsum(i_e + 1) = cnt;
|
||||
}
|
||||
#else
|
||||
{
|
||||
constexpr int lane_group_sz = 8;
|
||||
int lane_group_id = tid / lane_group_sz;
|
||||
@@ -835,61 +770,11 @@ struct MoeSortingKernel
|
||||
{
|
||||
cnt += wave_reduce(local_c[j], f_sum, number<8>{});
|
||||
}
|
||||
|
||||
// if constexpr(Problem::SubTokenTile == 2)
|
||||
// printf("i_e:%d, lane_group_os:%d -> %d, %d\n",
|
||||
// i_e, lane_group_os,
|
||||
// local_c[0],
|
||||
// local_c[1]);
|
||||
//
|
||||
// printf("i_e:%d, lane_group_os:%d, %d, %d, %d, %d, %d, %d, %d, %d\n",
|
||||
// i_e, lane_group_os,
|
||||
// local_c[0],
|
||||
// local_c[1],
|
||||
// local_c[2],
|
||||
// local_c[3],
|
||||
// local_c[4],
|
||||
// local_c[5],
|
||||
// local_c[6],
|
||||
// local_c[7]);
|
||||
#if 0
|
||||
#if 1
|
||||
cnt +=
|
||||
(i + 0 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[0], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 1 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[1], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 2 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[2], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 3 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[3], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 4 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[4], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 5 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[5], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 6 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[6], f_sum, number<8>{});
|
||||
cnt +=
|
||||
(i + 7 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[7], f_sum, number<8>{});
|
||||
#else
|
||||
// TODO: this relies on LDS OOB behavior, which is too hardware-specific
|
||||
cnt += wave_reduce(local_c[0], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[1], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[2], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[3], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[4], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[5], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[6], f_sum, number<8>{});
|
||||
cnt += wave_reduce(local_c[7], f_sum, number<8>{});
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
if(lane_group_os == 0)
|
||||
smem_cumsum(i_e + 1) = cnt;
|
||||
|
||||
// printf("i_e:%d, cnt:%d\n", i_e, cnt);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
if constexpr(Problem::LocalExpertMasking)
|
||||
{
|
||||
@@ -897,43 +782,22 @@ struct MoeSortingKernel
|
||||
for(int i_e = tid; i_e < num_experts; i_e += block_size)
|
||||
{
|
||||
// reuse this buffer
|
||||
// printf("tid:%d, m:%d\n", tid, local_expert_mask[i_e]);
|
||||
smem_cumdup(i_e + 1) = local_expert_mask[i_e];
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
#if 0
|
||||
if(tid == 0)
|
||||
{
|
||||
(void)lid;
|
||||
(void)wid;
|
||||
for(int i = 1; i <= num_experts; ++i)
|
||||
{
|
||||
// printf("e:%d -- %d (%d) \n", i - 1, smem_cumsum(i), sub_tokens);
|
||||
auto current_units = [&]() {
|
||||
index_t x_ = smem_cumsum(i) + unit_size_mdiv.divisor - 1;
|
||||
index_t y_ = unit_size_mdiv.div(x_);
|
||||
return max(y_, 1) * unit_size_mdiv.divisor;
|
||||
}();
|
||||
smem_cumsum(i) = smem_cumsum(i - 1) + current_units;
|
||||
}
|
||||
*p_total_tokens_post_pad = smem_cumsum(num_experts);
|
||||
}
|
||||
__syncthreads();
|
||||
#else
|
||||
|
||||
{
|
||||
if(wid == 0)
|
||||
{
|
||||
// NOTE: never call __syncthreads inside this block; only threads with wid == 0 enter it
|
||||
int i_e_ = 0;
|
||||
int local_cumsum_ = 0;
|
||||
// int pre_cumsum_ = 0;
|
||||
for(; i_e_ < num_experts; i_e_ += warpSize)
|
||||
{
|
||||
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
|
||||
// if((i_e_+lid) < num_experts)
|
||||
int local_cnt = smem_cumsum(i_e_ + lid + 1);
|
||||
int local_cnt = smem_cumsum(i_e_ + lid + 1);
|
||||
int blocks_pers_expert =
|
||||
unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
|
||||
|
||||
@@ -976,10 +840,7 @@ struct MoeSortingKernel
|
||||
// pre_cumsum_ has a value, which will result in
|
||||
// zero local cumsum(but we want at least padded)
|
||||
wave_cumsum<int, warpSize>(local_cumsum_);
|
||||
// printf(" lid:%d(%d), local_cnt:%d,pre_cumsum_:%d, %d--> %d (m:%d)\n", lid,
|
||||
// i_e_ +
|
||||
// lid, local_cnt, pre_cumsum_, padded_tokens_per_expert,local_cumsum_
|
||||
// ,local_masking);
|
||||
|
||||
if((i_e_ + lid) < num_experts)
|
||||
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
|
||||
|
||||
@@ -1003,7 +864,6 @@ struct MoeSortingKernel
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
#endif
|
||||
|
||||
for(int i_e = tid; i_e < num_experts; i_e += block_size)
|
||||
{
|
||||
@@ -1020,9 +880,6 @@ struct MoeSortingKernel
|
||||
return i_e;
|
||||
}();
|
||||
|
||||
// printf("i_e:%d, e_start:%d, e_end:%d, expert_id:%d (%d-%d, m:%d)\n", i_e, e_start,
|
||||
// e_end, expert_id, e_start, e_end, local_expert_mask[i_e]);
|
||||
|
||||
smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
@@ -1041,7 +898,6 @@ struct MoeSortingKernel
|
||||
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
|
||||
}
|
||||
}
|
||||
// if (tid == 0)
|
||||
smem_cumdup(num_experts) = smem_cumsum(num_experts);
|
||||
|
||||
// fill the p_sorted_token_ids/p_sorted_weights
|
||||
@@ -1068,40 +924,13 @@ struct MoeSortingKernel
|
||||
int i_t = i_token + curr_token_id;
|
||||
if(i_t < tokens)
|
||||
{
|
||||
int eid = topk_id[i_t * topk + curr_topk_id];
|
||||
// if(eid == 0) {
|
||||
// printf("@@@ eid:%d, i_t:%d, cur:%d, curr_topk_id:%d\n", eid, i_t,
|
||||
// curr_token_id, curr_topk_id); printf("## eid:%d,%d\n", i_t,
|
||||
// curr_topk_id);
|
||||
//}
|
||||
int eid = topk_id[i_t * topk + curr_topk_id];
|
||||
smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
#if 0
|
||||
for(int eid = tid; eid < num_experts; eid += block_size) {
|
||||
// indeed we can unroll 8x
|
||||
for(int i_sub_token = 0; i_sub_token < sub_tokens; i_sub_token++) {
|
||||
auto x = smem_tokens(i_sub_token, eid);
|
||||
//if (eid == 0)
|
||||
// printf("@@ eid:%d, pos:%d, i_sub_token:%d, x:%d\n", eid, smem_cumsum(eid), i_sub_token, x);
|
||||
if(x != 0) {
|
||||
// now x is topk value
|
||||
int position = smem_cumsum(eid);
|
||||
|
||||
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
p_sorted_token_ids[position] = MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1);
|
||||
#else
|
||||
p_sorted_token_ids[position] = i_token + i_sub_token;
|
||||
#endif
|
||||
p_sorted_weights[position] = weights[(i_token + i_sub_token) * topk + x - 1];
|
||||
smem_cumsum(eid) = position + 1; // increase position
|
||||
}
|
||||
// __syncthreads();
|
||||
}
|
||||
}
|
||||
#else
|
||||
{
|
||||
constexpr int lane_group_sz = 8;
|
||||
int lane_group_id = tid / lane_group_sz;
|
||||
@@ -1138,17 +967,12 @@ struct MoeSortingKernel
|
||||
|
||||
int remote_cnt = __builtin_amdgcn_ds_bpermute(
|
||||
(lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
|
||||
// printf("[%d]eid:%d, i_sub_token:%d, position:%d, x:%d, local_cnt:%d(%d),
|
||||
// remote_cnt:%d\n",
|
||||
// tid, eid, i_sub_token, position, x, local_cnt, local_cnt_cache,
|
||||
// remote_cnt);
|
||||
|
||||
position += remote_cnt;
|
||||
}
|
||||
smem_cumsum(eid) = position;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
// (void) weights;
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
@@ -1157,7 +981,6 @@ struct MoeSortingKernel
|
||||
{
|
||||
int e_start = smem_cumsum(eid);
|
||||
int e_end = smem_cumdup(eid + 1);
|
||||
// printf("--- eid:%d, e_start:%d, e_end:%d\n", eid, e_start, e_end);
|
||||
if constexpr(Problem::SkipExpertsWithZeroTokens)
|
||||
{
|
||||
if(e_start == e_end) // skip zero token expert
|
||||
@@ -1201,8 +1024,6 @@ struct MoeSortingKernel
|
||||
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
|
||||
kargs.num_experts,
|
||||
kargs.tokens,
|
||||
// kargs.tokens_per_thread,
|
||||
// numel,
|
||||
kargs.unit_size_mdiv,
|
||||
kargs.topk_mdiv,
|
||||
kargs.expert_mdiv,
|
||||
|
||||
Reference in New Issue
Block a user