[CK_TILE]Moe update index (#1672)

* update MOCK_ID for moe-sorting

* add moe-smoothquant

* update a comment

* fix format

* hot fix

* update topk in overflow case

* update comments

* update bf16 cvt

---------

Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
carlushuang
2024-11-25 13:12:35 +08:00
committed by GitHub
parent ce2bdf42a9
commit 36c7ce4e0e
36 changed files with 1321 additions and 11 deletions

View File

@@ -12,20 +12,77 @@
namespace ck_tile {
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
//
// 32bit 0........23 24.....31 bit
// (data) -> (token_id | topk_id)
// low 24 bit is for token id, top 8 bit is for topk id
//
// the input after smooth-quant is [topk, token, hidden_dim], originally it is [token, hidden_dim]
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
struct MoeSortingHostArgs
{
const void* p_topk_ids;
const void* p_weights;
const void* p_topk_ids; // [token, topk]
const void* p_weights; // [token, topk]
void* p_sorted_token_ids;
void* p_sorted_weights;
void* p_sorted_expert_ids;
void* p_total_tokens_post_pad;
// we fused the setzero of output of fused-moe buffer
// set this pointer to nullptr will skip this operation
void* p_moe_buf;
index_t tokens;
index_t unit_size;
index_t unit_size; // this is the M_a of fused-moe kernel
index_t num_experts;
index_t topk;
index_t moe_buf_bytes;
index_t moe_buf_bytes; // byte size of p_moe_buf
};
template <typename Problem_>
@@ -183,8 +240,14 @@ struct MoeSortingKernel
index_t expert_id = topk_id[i];
index_t rank_post_pad =
tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id];
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
uint32_t curr_token_id, curr_topk_id;
topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
p_sorted_token_ids[rank_post_pad] = MOE_SORTING_MOCK_ID(curr_token_id, curr_topk_id);
#else
p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
p_sorted_weights[rank_post_pad] = weights[i];
#endif
p_sorted_weights[rank_post_pad] = weights[i];
++tokens_cnts[calc_index(num_experts, tid, expert_id)];
}
@@ -195,8 +258,13 @@ struct MoeSortingKernel
cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)];
while(expert_offset < cumsum[tid + 1])
{
#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
p_sorted_token_ids[expert_offset] =
MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
#else
p_sorted_token_ids[expert_offset] = prefill_token;
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
#endif
p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
expert_offset++;
}
}
@@ -229,4 +297,7 @@ struct MoeSortingKernel
smem);
}
};
#undef MOE_SORTING_MOCK_ID
} // namespace ck_tile

View File

@@ -3,6 +3,7 @@
#pragma once
#include "ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp"
#include "ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp"

View File

@@ -0,0 +1,205 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
// host side args
struct MoeSmoothquantHostArgs
{
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32
const void* p_topk_ids; // [tokens, topk]
void* p_yscale; // [topk * tokens, 1], output, rowwise quant scale
void* p_qy; // [topk * tokens, hidden_size], output
index_t tokens;
index_t hidden_size;
index_t experts;
index_t topk;
index_t x_stride; // input x row stride
index_t y_stride; // output y stride(stride for topk)
};
// TODO: Extract some type to wrapper class
template <typename Pipeline_>
struct MoeSmoothquant
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = typename Pipeline::Problem;
using XDataType = remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using QYDataType = remove_cvref_t<typename Problem::QYDataType>;
static constexpr index_t Block_M = Problem::BlockShape::Block_M;
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
static constexpr bool kPadM = false; // always no need to pad along M
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kTwoPass = Problem::kTwoPass;
static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static_assert(Problem::BlockShape::Repeat_M == 1);
struct Kargs
{
const void* p_x; // [tokens ,hidden_size], input, fp16/bf16
const void* p_xscale; // [experts, hidden_size], input, columnwise scale, fp32
const void* p_topk_ids; // [tokens, topk]
void* p_yscale; // [topk, tokens, 1], output, rowwise quant scale
void* p_qy; // [topk, tokens, hidden_size], output
index_t tokens;
index_t hidden_size;
index_t experts;
index_t topk;
index_t x_stride; // input x row stride
index_t y_stride; // output y stride(stride for topk)
};
using Hargs = MoeSmoothquantHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
return Kargs{hargs.p_x,
hargs.p_xscale,
hargs.p_topk_ids,
hargs.p_yscale,
hargs.p_qy,
hargs.tokens,
hargs.hidden_size,
hargs.experts,
hargs.topk,
hargs.x_stride,
hargs.y_stride};
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
return dim3(hargs.topk, integer_divide_ceil(hargs.tokens, Block_M), 1);
}
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
// in byte
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
CK_TILE_HOST static std::string GetName()
{
// clang-format off
using S_ = typename Problem::BlockShape;
auto surfix = [&] () {
std::string n;
if (kPadN) n += "_pn";
if (kTwoPass) n += "_2p";
return n; }();
#define _SS_ std::string
#define _TS_ std::to_string
return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" +
_TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
_TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
_SS_(Pipeline::name) + surfix;
#undef _SS_
#undef _TS_
// clang-format on
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
const index_t i_topk = blockIdx.x;
const index_t i_token = blockIdx.y * Block_M;
const index_t i_token_in_thrd =
__builtin_amdgcn_readfirstlane(threadIdx.x / Problem::BlockShape::ThreadPerBlock_N);
const index_t i_expert = reinterpret_cast<const index_t*>(
kargs.p_topk_ids)[(i_token + i_token_in_thrd) * kargs.topk + i_topk];
// [tokens ,hidden_size]
const auto x_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.tokens, kargs.hidden_size),
make_tuple(kargs.x_stride, 1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {i_token, 0});
}();
// [experts, hidden_size],
const auto xscale_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XScaleDataType*>(kargs.p_xscale) + i_expert * kargs.hidden_size,
make_tuple(kargs.hidden_size),
make_tuple(1),
number<Vector_N>{},
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
}();
// [topk, tokens]
auto yscale_window = [&]() {
const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<YScaleDataType*>(kargs.p_yscale) + i_topk * kargs.tokens,
make_tuple(kargs.tokens),
make_tuple(1),
number<1>{});
const auto tmp2_ =
pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});
return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {i_token});
}();
// [topk, tokens, hidden_size]
auto qy_window = [&]() {
auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
static_cast<QYDataType*>(kargs.p_qy) + i_topk * kargs.tokens * kargs.y_stride,
make_tuple(kargs.tokens, kargs.hidden_size),
make_tuple(kargs.y_stride, 1),
number<Vector_N>{},
number<1>{});
auto tmp2_ = pad_tensor_view(
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
return make_tile_window(
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {i_token, 0});
}();
__shared__ char smem[GetSmemSize()];
Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.hidden_size, smem);
}
};
} // namespace ck_tile