fix moe gemm for not gate only

This commit is contained in:
zanzhang
2025-05-21 15:38:34 +08:00
parent 9d8a21dfa9
commit 299c63d198
3 changed files with 191 additions and 265 deletions

View File

@@ -117,6 +117,68 @@ struct BlockFlatmmASmemBSmemCRegV1
});
});
}
// C += A * B
template <typename CBlockTensor, typename ABlockTensor, typename BFlatBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
ABlockTensor& a_block_tensor,
BFlatBlockTensor& b_warp_tensor) const
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp =
BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
};
} // namespace ck_tile

View File

@@ -66,6 +66,7 @@ struct MoeGemmKernel
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
static constexpr bool IsInputGemm = FlatmmPipeline::IsInputGemm;
static constexpr bool IsGateOnly = FlatmmPipeline::IsGateOnly;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
@@ -134,7 +135,6 @@ struct MoeGemmKernel
CK_TILE_HOST static constexpr MoeGemmKernelArgs MakeKernelArgs(const MoeGemmHostArgs& hostArgs)
{
printf("in moe gemm kernel args! \n");
return MoeGemmKernelArgs{hostArgs.p_sorted_token_ids,
hostArgs.p_sorted_expert_ids,
hostArgs.p_max_token_id,
@@ -263,70 +263,6 @@ struct MoeGemmKernel
number<1>{});
}();
// const auto& b_tensor_view = [&]() {
// if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
// {
// if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
// {
// constexpr index_t K1 = FlatmmPipeline::GetSmemPackB();
// const index_t K0 = splitk_batch_offset.splitted_k / K1;
// constexpr index_t VectorSizeB = std::min(K1, FlatmmPipeline::GetVectorSizeB());
// const auto b_k0_n_k1_desc =
// make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
// make_tuple(kargs.N * K1, K1, I1),
// number<VectorSizeB>{},
// number<1>{});
// const auto b_n_k_desc = transform_tensor_descriptor(
// b_k0_n_k1_desc,
// make_tuple(make_merge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(kargs.N)),
// make_tuple(sequence<0, 2>{}, sequence<1>{}),
// make_tuple(sequence<0>{}, sequence<1>{}));
// return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
// }
// else
// {
// return make_naive_tensor_view<address_space_enum::global>(
// b_ptr,
// make_tuple(splitk_batch_offset.splitted_k, kargs.N),
// make_tuple(kargs.stride_B, 1),
// number<FlatmmPipeline::GetVectorSizeB()>{},
// number<1>{});
// }
// }
// else
// {
// if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
// {
// constexpr index_t K1 = FlatmmPipeline::GetSmemPackB();
// const index_t K0 = splitk_batch_offset.splitted_k / K1;
// constexpr index_t VectorSizeB = std::min(K1, FlatmmPipeline::GetVectorSizeB());
// const auto b_k0_n_k1_desc =
// make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
// make_tuple(kargs.N * K1, K1, I1),
// number<VectorSizeB>{},
// number<1>{});
// const auto b_n_k_desc = transform_tensor_descriptor(
// b_k0_n_k1_desc,
// make_tuple(make_merge_transform(make_tuple(K0, K1)),
// make_pass_through_transform(kargs.N)),
// make_tuple(sequence<0, 2>{}, sequence<1>{}),
// make_tuple(sequence<1>{}, sequence<0>{}));
// return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
// }
// else
// {
// return make_naive_tensor_view<address_space_enum::global>(
// b_ptr,
// make_tuple(kargs.N, splitk_batch_offset.splitted_k),
// make_tuple(kargs.stride_B, 1),
// number<FlatmmPipeline::GetVectorSizeB()>{},
// number<1>{});
// }
// }
// }();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
@@ -422,29 +358,6 @@ struct MoeGemmKernel
make_tuple(sequence<0>{}, sequence<1>{}));
}
// template <typename CView>
// CK_TILE_DEVICE static auto GetCTransformGemmView(const CView& view, const index_t token_id)
// {
// if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
// return transform_tensor_view(
// view,
// make_tuple(make_indexing_transform(
// view.get_tensor_descriptor().get_length(number<0>()), token_id),
// make_pass_through_transform(
// view.get_tensor_descriptor().get_length(number<1>()))),
// make_tuple(sequence<0>{}, sequence<1>{}),
// make_tuple(sequence<0>{}, sequence<1>{}));
// else
// return transform_tensor_view(
// view,
// make_tuple(make_pass_through_transform(
// view.get_tensor_descriptor().get_length(number<0>())),
// make_indexing_transform(
// view.get_tensor_descriptor().get_length(number<1>()), token_id)),
// make_tuple(sequence<0>{}, sequence<1>{}),
// make_tuple(sequence<0>{}, sequence<1>{}));
// }
template <typename PadView>
CK_TILE_DEVICE static auto TransformGemmPadViews(const PadView& views, const index_t token_id)
{

View File

@@ -11,7 +11,7 @@
namespace ck_tile {
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
struct MoeGemmPipelineAgBgCrImpl
struct MoeGemmPipelineAgBgCrImpl : public FlatmmPipelineAGmemBGmemCRegV1
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
@@ -95,140 +95,6 @@ struct MoeGemmPipelineAgBgCrImpl
return MRepeat;
}
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp, typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindow::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindow{}.get_window_lengths()[number<0>{}],
"wrong!");
static_assert(kKPerBlock == ADramBlockWindow{}.get_window_lengths()[number<1>{}],
"wrong!");
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// Block GEMM
auto block_flatmm = BlockFlatmm();
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
// Acc register tile
auto c_block_tile = decltype(block_flatmm(a_lds_gemm_window, b_flat_dram_window)){};
// prefetch
// global read 0
auto a_block_tile = a_dram_block_window.load();
{
// move to 1
move_tile_window(a_dram_block_window, {0, kKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
}
else
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
}
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
// global read i + 1
a_dram_block_window.load(a_block_tile);
block_sync_lds();
// GEMM i
block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window);
block_sync_lds();
// move to i + 2
move_tile_window(a_dram_block_window, {0, kKPerBlock});
// LDS write i + 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// move to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
iCounter--;
}
// tail
{
block_sync_lds();
// GEMM num_loop - 1
block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window);
}
sweep_tile(c_block_tile,
[&](auto idx0, auto idx1) {
fp32x2_t v_{c_block_tile(idx0), c_block_tile(idx1)};
GateActivation{}(v_, v_);
c_block_tile(idx0) = v_.x;
c_block_tile(idx1) = v_.y;
},
sequence<1, 2>{});
return c_block_tile;
}
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp>
CK_TILE_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_flat_dram_block_window_tmp,
num_loop,
p_smem);
}
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp, typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window,
const AElementFunction& a_element_func,
@@ -246,6 +112,26 @@ struct MoeGemmPipelineAgBgCrImpl
static_assert(kKPerBlock == ADramBlockWindow{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
// A tile in LDS
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
@@ -268,106 +154,171 @@ struct MoeGemmPipelineAgBgCrImpl
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
auto b_gate_flat_dram_window =
auto b_gate_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
b_flat_dram_block_window_tmp.move({N, 0})
auto b_up_flat_dram_window =
move_tile(b_flat_dram_block_window_tmp, {N, 0});
auto b_up_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
// Acc register tile
using c_block_tile_type = decltype(block_flatmm(a_lds_gemm_window, b_gate_flat_dram_window));
auto c_block_tiles[2] = {c_block_tile_type{}, c_block_tile_type{}};
auto c_gate_block_tile = c_block_tile_type{};
auto c_up_block_tile = c_block_tile_type{}
// prefetch
// global read 0
auto a_block_tile = a_dram_block_window.load();
a_block_tile = load_tile(a_dram_block_window);
statically_indexed_array<
statically_indexed_array<decltype(b_gate_flat_dram_window), KIterPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_gate_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
b_warp_tensor;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_up_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_2;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_gate_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
{
// move to 1
move_tile_window(a_dram_block_window, {0, kKPerBlock});
// move to next flat K
move_tile_window(b_gate_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tiles[0]);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tiles[1]);
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
}
else
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
block_sync_lds();
}
index_t iCounter = num_loop - 1;
while(iCounter > 0)
{
// global read i + 1
a_dram_block_window.load(a_block_tile);
block_sync_lds();
a_block_tile = load_tile(a_dram_block_window);
// GEMM i
block_flatmm(c_block_tiles[0], a_lds_gemm_window, b_gate_flat_dram_window);
//TODO: simply add b_gate flatmm
block_flatmm(c_block_tiles[1], a_lds_gemm_window, b_up_flat_dram_window);
block_flatmm(c_gate_block_tile, a_warp_windows, b_warp_tensor);
block_sync_lds();
// move to i + 2
move_tile_window(a_dram_block_window, {0, kKPerBlock});
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_up_flat_dram_window;
// LDS write i + 1
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// move to next flat K
move_tile_window(b_up_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// GEMM i
block_flatmm(c_up_block_tile, a_warp_windows, b_warp_tensor_2);
block_sync_lds();
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_gate_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// move to i + 2
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// move to next flat K
move_tile_window(b_gate_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(b_up_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
// LDS write i + 1
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window, a_block_tile_tmp);
// HotLoopScheduler();
block_sync_lds();
iCounter--;
}
// tail
{
// GEMM i
block_flatmm(c_gate_block_tile, a_warp_windows, b_warp_tensor);
block_sync_lds();
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_up_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// HotLoopScheduler();
block_sync_lds();
// GEMM num_loop - 1
block_flatmm(c_block_tiles[0], a_lds_gemm_window, b_gate_flat_dram_window);
block_flatmm(c_block_tiles[1], a_lds_gemm_window, b_up_flat_dram_window);
block_flatmm(c_up_block_tile, a_warp_windows, b_warp_tensor_2);
}
sweep_tile(c_block_tiles[0],
sweep_tile(c_gate_block_tile,
[&](auto idx0, auto idx1) {
fp32x2_t v_{c_block_tiles[0].at(number<0>{})(idx0), c_block_tiles[0].at(number<0>{})(idx1)};
fp32x2_t v_{c_gate_block_tile.at(number<0>{})(idx0), c_gate_block_tile.at(number<0>{})(idx1)};
typename Problem::GateActivation{}(v_, v_);
c_block_tiles[0].at(number<0>{})(idx0) = v_.x;
c_block_tiles[0].at(number<0>{})(idx1) = v_.y;
c_gate_block_tile.at(number<0>{})(idx0) = v_.x;
c_gate_block_tile.at(number<0>{})(idx1) = v_.y;
},
sequence<1, 2>{});
auto c_block_tile =
tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; },
c_block_tiles[0],
c_block_tiles[1]);
c_gate_block_tile,
c_up_block_tile);
return c_block_tiles[0];
return c_block_tile;
}
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp>