[CK_TILE] Fix mock token id, support g1u1/g1u0 through same inline code block (#1808)

* fix mock token id

* prepare host for g1u1

* reformat inline-asm

* restructure uk_0

* restructure gate_up

* done

* change default to init=1

* update readme

* fix a bug in interleave pipeline

* rcp for silu
This commit is contained in:
carlushuang
2025-01-16 17:51:10 +08:00
committed by GitHub
parent 8c29e06f3c
commit 1ff50e78c6
20 changed files with 1927 additions and 1413 deletions

View File

@@ -73,7 +73,7 @@ void reference_fused_moe(
ck_tile::index_t tokens,
ck_tile::index_t experts,
ck_tile::index_t hidden_size,
ck_tile::index_t intermediate_size, // this size is for gate/up
ck_tile::index_t intermediate_size, // this size is for gate/up/down
ck_tile::index_t topk,
ck_tile::index_t gate_only)
{
@@ -82,19 +82,8 @@ void reference_fused_moe(
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
assert(num_sorted_tiles_host.get_element_size() == 1);
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
ck_tile::index_t intermediate_size_0 = intermediate_size;
ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
// TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
{
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
ck_tile::index_t intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2);
ck_tile::index_t intermediate_size_1 = intermediate_size;
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
@@ -105,11 +94,31 @@ void reference_fused_moe(
if(i_tile >= num_sorted_tiles)
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
ck_tile::index_t i_topk = i_token >> 24;
i_token &= 0xffffff;
if(i_token >= tokens)
return;
(void)token_ids_host;
#else
// TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
{
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
if(i_token >= tokens)
return;
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
auto weight = sorted_weight_host.mData[i_flatten];
#endif
auto weight = sorted_weight_host.mData[i_flatten];
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
// first gemm