tidy code

This commit is contained in:
carlushuang
2025-02-10 18:30:47 +08:00
parent 8b3e32a96d
commit a87dfaddbf
5 changed files with 22 additions and 196 deletions

View File

@@ -131,8 +131,6 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
// std::cout << "topk_id:" << topk_ids_host << std::endl;
// std::cout << "local_expert_masking:" << local_expert_masking_host << std::endl;
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
@@ -177,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
warmup,
repeat};
auto ms = moe_sorting(trait, karg, sc);
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
index_prec.c_str(),
weight_prec.c_str(),
tokens,
num_experts,
topk,
ms);
topk);
if(local_expert_masking)
{
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
}
if(ms < 0)
printf("not supported\n");
else
printf("ms:%f, ", ms);
fflush(stdout);
if(ms < 0)
{
@@ -221,19 +226,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
local_expert_masking);
rtn &= ck_tile::check_err(
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
// std::cout << "sorted_ids_ref:"<<sorted_ids_ref<<std::endl;
rtn &= ck_tile::check_err(sorted_weights_host,
sorted_weights_ref,
std::string("OUT Error: Incorrect w!"),
1e-6,
1e-6);
// std::cout << "sorted_weights_ref:"<<sorted_weights_ref<<std::endl;
rtn &= ck_tile::check_err(sorted_expert_ids_host,
sorted_expert_ids_ref,
std::string("OUT Error: Incorrect eid!"),
1e-6,
1e-6);
// std::cout << "sorted_expert_ids_ref:"<<sorted_expert_ids_ref<<std::endl;
if(moe_buf_size)
{
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});

View File

@@ -22,3 +22,7 @@ $EXE -t=64 -e=455 -k=8
$EXE -t=777 -e=802 -k=99
$EXE -t=4097 -e=906 -k=51
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129

View File

@@ -27,12 +27,12 @@
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/int8.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"

View File

@@ -104,7 +104,6 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
for(index_t s = 0; s < expert_slices[e]; s++)
{
// out_expert_id[s] = e;
out_expert_id[s] = curr_expert_id;
unit_cnt++;
}

View File

@@ -351,7 +351,6 @@ struct MoeSortingKernel
bound_ctrl)); // row_newbcast:7
data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
// printf("[%d]eid:%d, thread_data:%d, xxx:%d, yyy:%d (%d)\n", threadIdx.x, threadIdx.x/8, thread_data, xxx, yyy, (__lane_id() / 8) % 2);
thread_data = thread_data - yyy;
#endif
@@ -683,12 +682,9 @@ struct MoeSortingKernel
index_t* p_total_tokens_post_pad,
const index_t num_experts,
const index_t tokens,
// const index_t tokens_per_thread,
// const index_t numel,
const mdiv unit_size_mdiv,
const mdiv topk_mdiv,
const mdiv expert_mdiv,
// const mdiv sub_tokens_mdiv,
const index_t smem_rows,
void* smem) const
{
@@ -701,18 +697,12 @@ struct MoeSortingKernel
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
const index_t smem_cols = num_experts + 1;
// const index_t total_smem_tokens_pixel = sub_tokens * num_experts; // no need consider
// padding
#if 0
simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem), smem_cols};
simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + sub_tokens * smem_cols};
simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + sub_tokens * smem_cols + smem_cols};
#else
simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + 0};
simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + smem_cols};
simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem) + 2 * smem_cols,
smem_cols};
#endif
// #pragma unroll 8
for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
{
@@ -733,13 +723,8 @@ struct MoeSortingKernel
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
// printf("eid:%d, [%d] tid:%d, (i_token:%d, curr_token_id:%d)i_t:%d,
// curr_topk_id:%d, tokens:%d\n",
// eid, i, tid, i_token, curr_token_id, i_t, curr_topk_id, tokens);
// smem_tokens(curr_token_id, eid)++;
if constexpr(Problem::SubTokenOneShot)
smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
else
@@ -748,27 +733,6 @@ struct MoeSortingKernel
__builtin_amdgcn_s_waitcnt(0xc07f);
}
__syncthreads(); // make sure different i_token iteration not overlap by different wave
// if(tid == 0) {
// int e0 = smem_tokens(0, 0);
// int e1 = smem_tokens(1, 0);
// int e2 = smem_tokens(2, 0);
// int e3 = smem_tokens(3, 0);
// int e4 = smem_tokens(4, 0);
// int e5 = smem_tokens(5, 0);
// int e6 = smem_tokens(6, 0);
// int e7 = smem_tokens(7, 0);
// printf("xxx eid:%d i_token:%d, cnt:%d,%d,%d,%d,%d,%d,%d,%d(%d)\n", 0, i_token,
// e0,
// e1,
// e2,
// e3,
// e4,
// e5,
// e6,
// e7,
// e0+e1+e2+e3+e4+e5+e6+e7
// );
// }
}
// counting
@@ -777,36 +741,7 @@ struct MoeSortingKernel
smem_cumsum(0) = 0;
// smem_cumdup(0) = 0;
}
#if 0
(void)f_sum;
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
index_t local_c[8];
index_t cnt = 0;
// TODO: manually unroll. pragma unroll does not work well when we have dependency
for(int i = 0; i < sub_tokens; i += 8)
{
local_c[0] = smem_tokens(i + 0, i_e);
local_c[1] = smem_tokens(i + 1, i_e);
local_c[2] = smem_tokens(i + 2, i_e);
local_c[3] = smem_tokens(i + 3, i_e);
local_c[4] = smem_tokens(i + 4, i_e);
local_c[5] = smem_tokens(i + 5, i_e);
local_c[6] = smem_tokens(i + 6, i_e);
local_c[7] = smem_tokens(i + 7, i_e);
cnt += local_c[0];
cnt += local_c[1];
cnt += local_c[2];
cnt += local_c[3];
cnt += local_c[4];
cnt += local_c[5];
cnt += local_c[6];
cnt += local_c[7];
}
smem_cumsum(i_e + 1) = cnt;
}
#else
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
@@ -835,61 +770,11 @@ struct MoeSortingKernel
{
cnt += wave_reduce(local_c[j], f_sum, number<8>{});
}
// if constexpr(Problem::SubTokenTile == 2)
// printf("i_e:%d, lane_group_os:%d -> %d, %d\n",
// i_e, lane_group_os,
// local_c[0],
// local_c[1]);
//
// printf("i_e:%d, lane_group_os:%d, %d, %d, %d, %d, %d, %d, %d, %d\n",
// i_e, lane_group_os,
// local_c[0],
// local_c[1],
// local_c[2],
// local_c[3],
// local_c[4],
// local_c[5],
// local_c[6],
// local_c[7]);
#if 0
#if 1
cnt +=
(i + 0 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[0], f_sum, number<8>{});
cnt +=
(i + 1 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[1], f_sum, number<8>{});
cnt +=
(i + 2 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[2], f_sum, number<8>{});
cnt +=
(i + 3 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[3], f_sum, number<8>{});
cnt +=
(i + 4 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[4], f_sum, number<8>{});
cnt +=
(i + 5 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[5], f_sum, number<8>{});
cnt +=
(i + 6 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[6], f_sum, number<8>{});
cnt +=
(i + 7 * 8 >= sub_tokens) ? 0 : wave_reduce(local_c[7], f_sum, number<8>{});
#else
// TODO: this rely on LDS OOB behavior, too hardware specific
cnt += wave_reduce(local_c[0], f_sum, number<8>{});
cnt += wave_reduce(local_c[1], f_sum, number<8>{});
cnt += wave_reduce(local_c[2], f_sum, number<8>{});
cnt += wave_reduce(local_c[3], f_sum, number<8>{});
cnt += wave_reduce(local_c[4], f_sum, number<8>{});
cnt += wave_reduce(local_c[5], f_sum, number<8>{});
cnt += wave_reduce(local_c[6], f_sum, number<8>{});
cnt += wave_reduce(local_c[7], f_sum, number<8>{});
#endif
#endif
}
if(lane_group_os == 0)
smem_cumsum(i_e + 1) = cnt;
// printf("i_e:%d, cnt:%d\n", i_e, cnt);
}
}
#endif
if constexpr(Problem::LocalExpertMasking)
{
@@ -897,43 +782,22 @@ struct MoeSortingKernel
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
// reuse this buffer
// printf("tid:%d, m:%d\n", tid, local_expert_mask[i_e]);
smem_cumdup(i_e + 1) = local_expert_mask[i_e];
}
}
__syncthreads();
#if 0
if(tid == 0)
{
(void)lid;
(void)wid;
for(int i = 1; i <= num_experts; ++i)
{
// printf("e:%d -- %d (%d) \n", i - 1, smem_cumsum(i), sub_tokens);
auto current_units = [&]() {
index_t x_ = smem_cumsum(i) + unit_size_mdiv.divisor - 1;
index_t y_ = unit_size_mdiv.div(x_);
return max(y_, 1) * unit_size_mdiv.divisor;
}();
smem_cumsum(i) = smem_cumsum(i - 1) + current_units;
}
*p_total_tokens_post_pad = smem_cumsum(num_experts);
}
__syncthreads();
#else
{
if(wid == 0)
{
// NOTE: under this block can never use __syncthreads!
int i_e_ = 0;
int local_cumsum_ = 0;
// int pre_cumsum_ = 0;
for(; i_e_ < num_experts; i_e_ += warpSize)
{
int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
// if((i_e_+lid) < num_experts)
int local_cnt = smem_cumsum(i_e_ + lid + 1);
int local_cnt = smem_cumsum(i_e_ + lid + 1);
int blocks_pers_expert =
unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
@@ -976,10 +840,7 @@ struct MoeSortingKernel
// pre_sumsum has value, which will result int
// zero local cumsum(but we want at least padded)
wave_cumsum<int, warpSize>(local_cumsum_);
// printf(" lid:%d(%d), local_cnt:%d,pre_cumsum_:%d, %d--> %d (m:%d)\n", lid,
// i_e_ +
// lid, local_cnt, pre_cumsum_, padded_tokens_per_expert,local_cumsum_
// ,local_masking);
if((i_e_ + lid) < num_experts)
smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
@@ -1003,7 +864,6 @@ struct MoeSortingKernel
}
__syncthreads();
}
#endif
for(int i_e = tid; i_e < num_experts; i_e += block_size)
{
@@ -1020,9 +880,6 @@ struct MoeSortingKernel
return i_e;
}();
// printf("i_e:%d, e_start:%d, e_end:%d, expert_id:%d (%d-%d, m:%d)\n", i_e, e_start,
// e_end, expert_id, e_start, e_end, local_expert_mask[i_e]);
smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
@@ -1041,7 +898,6 @@ struct MoeSortingKernel
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
}
}
// if (tid == 0)
smem_cumdup(num_experts) = smem_cumsum(num_experts);
// fill the p_sorted_token_ids/p_sorted_weights
@@ -1068,40 +924,13 @@ struct MoeSortingKernel
int i_t = i_token + curr_token_id;
if(i_t < tokens)
{
int eid = topk_id[i_t * topk + curr_topk_id];
// if(eid == 0) {
// printf("@@@ eid:%d, i_t:%d, cur:%d, curr_topk_id:%d\n", eid, i_t,
// curr_token_id, curr_topk_id); printf("## eid:%d,%d\n", i_t,
// curr_topk_id);
//}
int eid = topk_id[i_t * topk + curr_topk_id];
smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1
}
}
__syncthreads();
}
#if 0
for(int eid = tid; eid < num_experts; eid += block_size) {
// indeed we can unroll 8x
for(int i_sub_token = 0; i_sub_token < sub_tokens; i_sub_token++) {
auto x = smem_tokens(i_sub_token, eid);
//if (eid == 0)
// printf("@@ eid:%d, pos:%d, i_sub_token:%d, x:%d\n", eid, smem_cumsum(eid), i_sub_token, x);
if(x != 0) {
// now x is topk value
int position = smem_cumsum(eid);
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[position] = MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1);
#else
p_sorted_token_ids[position] = i_token + i_sub_token;
#endif
p_sorted_weights[position] = weights[(i_token + i_sub_token) * topk + x - 1];
smem_cumsum(eid) = position + 1; // increase position
}
// __syncthreads();
}
}
#else
{
constexpr int lane_group_sz = 8;
int lane_group_id = tid / lane_group_sz;
@@ -1138,17 +967,12 @@ struct MoeSortingKernel
int remote_cnt = __builtin_amdgcn_ds_bpermute(
(lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
// printf("[%d]eid:%d, i_sub_token:%d, position:%d, x:%d, local_cnt:%d(%d),
// remote_cnt:%d\n",
// tid, eid, i_sub_token, position, x, local_cnt, local_cnt_cache,
// remote_cnt);
position += remote_cnt;
}
smem_cumsum(eid) = position;
}
}
#endif
// (void) weights;
__syncthreads();
}
@@ -1157,7 +981,6 @@ struct MoeSortingKernel
{
int e_start = smem_cumsum(eid);
int e_end = smem_cumdup(eid + 1);
// printf("--- eid:%d, e_start:%d, e_end:%d\n", eid, e_start, e_end);
if constexpr(Problem::SkipExpertsWithZeroTokens)
{
if(e_start == e_end) // skip zero token expert
@@ -1201,8 +1024,6 @@ struct MoeSortingKernel
static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
kargs.num_experts,
kargs.tokens,
// kargs.tokens_per_thread,
// numel,
kargs.unit_size_mdiv,
kargs.topk_mdiv,
kargs.expert_mdiv,