mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
fix moe gemm for not gate only
This commit is contained in:
@@ -117,6 +117,68 @@ struct BlockFlatmmASmemBSmemCRegV1
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensor, typename BFlatBlockTensor>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
ABlockTensor& a_block_tensor,
|
||||
BFlatBlockTensor& b_warp_tensor) const
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp =
|
||||
BlockTile::at(idxN) / (WarpTile::at(idxN) * BlockWarps::at(idxN));
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -66,6 +66,7 @@ struct MoeGemmKernel
|
||||
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
|
||||
|
||||
static constexpr bool IsInputGemm = FlatmmPipeline::IsInputGemm;
|
||||
static constexpr bool IsGateOnly = FlatmmPipeline::IsGateOnly;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
@@ -134,7 +135,6 @@ struct MoeGemmKernel
|
||||
|
||||
CK_TILE_HOST static constexpr MoeGemmKernelArgs MakeKernelArgs(const MoeGemmHostArgs& hostArgs)
|
||||
{
|
||||
printf("in moe gemm kernel args! \n");
|
||||
return MoeGemmKernelArgs{hostArgs.p_sorted_token_ids,
|
||||
hostArgs.p_sorted_expert_ids,
|
||||
hostArgs.p_max_token_id,
|
||||
@@ -263,70 +263,6 @@ struct MoeGemmKernel
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
|
||||
// const auto& b_tensor_view = [&]() {
|
||||
// if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
// {
|
||||
// if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
// {
|
||||
// constexpr index_t K1 = FlatmmPipeline::GetSmemPackB();
|
||||
// const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
// constexpr index_t VectorSizeB = std::min(K1, FlatmmPipeline::GetVectorSizeB());
|
||||
// const auto b_k0_n_k1_desc =
|
||||
// make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
// make_tuple(kargs.N * K1, K1, I1),
|
||||
// number<VectorSizeB>{},
|
||||
// number<1>{});
|
||||
// const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
// b_k0_n_k1_desc,
|
||||
// make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
// make_pass_through_transform(kargs.N)),
|
||||
// make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
// return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// return make_naive_tensor_view<address_space_enum::global>(
|
||||
// b_ptr,
|
||||
// make_tuple(splitk_batch_offset.splitted_k, kargs.N),
|
||||
// make_tuple(kargs.stride_B, 1),
|
||||
// number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
// number<1>{});
|
||||
// }
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
// {
|
||||
// constexpr index_t K1 = FlatmmPipeline::GetSmemPackB();
|
||||
// const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
// constexpr index_t VectorSizeB = std::min(K1, FlatmmPipeline::GetVectorSizeB());
|
||||
// const auto b_k0_n_k1_desc =
|
||||
// make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1),
|
||||
// make_tuple(kargs.N * K1, K1, I1),
|
||||
// number<VectorSizeB>{},
|
||||
// number<1>{});
|
||||
// const auto b_n_k_desc = transform_tensor_descriptor(
|
||||
// b_k0_n_k1_desc,
|
||||
// make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
// make_pass_through_transform(kargs.N)),
|
||||
// make_tuple(sequence<0, 2>{}, sequence<1>{}),
|
||||
// make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
// return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// return make_naive_tensor_view<address_space_enum::global>(
|
||||
// b_ptr,
|
||||
// make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
// make_tuple(kargs.stride_B, 1),
|
||||
// number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
// number<1>{});
|
||||
// }
|
||||
// }
|
||||
// }();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -422,29 +358,6 @@ struct MoeGemmKernel
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
|
||||
// template <typename CView>
|
||||
// CK_TILE_DEVICE static auto GetCTransformGemmView(const CView& view, const index_t token_id)
|
||||
// {
|
||||
// if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
|
||||
// return transform_tensor_view(
|
||||
// view,
|
||||
// make_tuple(make_indexing_transform(
|
||||
// view.get_tensor_descriptor().get_length(number<0>()), token_id),
|
||||
// make_pass_through_transform(
|
||||
// view.get_tensor_descriptor().get_length(number<1>()))),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
// else
|
||||
// return transform_tensor_view(
|
||||
// view,
|
||||
// make_tuple(make_pass_through_transform(
|
||||
// view.get_tensor_descriptor().get_length(number<0>())),
|
||||
// make_indexing_transform(
|
||||
// view.get_tensor_descriptor().get_length(number<1>()), token_id)),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}),
|
||||
// make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
// }
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto TransformGemmPadViews(const PadView& views, const index_t token_id)
|
||||
{
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
|
||||
struct MoeGemmPipelineAgBgCrImpl
|
||||
struct MoeGemmPipelineAgBgCrImpl : public FlatmmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
@@ -95,140 +95,6 @@ struct MoeGemmPipelineAgBgCrImpl
|
||||
return MRepeat;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindow::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindow{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
static_assert(kKPerBlock == ADramBlockWindow{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
auto block_flatmm = BlockFlatmm();
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
|
||||
auto b_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_flatmm(a_lds_gemm_window, b_flat_dram_window)){};
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = a_dram_block_window.load();
|
||||
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(a_dram_block_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
|
||||
shuffle_tile(a_shuffle_tmp, a_block_tile);
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
}
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
a_dram_block_window.load(a_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_dram_block_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
|
||||
// move to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window);
|
||||
}
|
||||
|
||||
sweep_tile(c_block_tile,
|
||||
[&](auto idx0, auto idx1) {
|
||||
fp32x2_t v_{c_block_tile(idx0), c_block_tile(idx1)};
|
||||
GateActivation{}(v_, v_);
|
||||
c_block_tile(idx0) = v_.x;
|
||||
c_block_tile(idx1) = v_.y;
|
||||
},
|
||||
sequence<1, 2>{});
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window,
|
||||
const AElementFunction& a_element_func,
|
||||
@@ -246,6 +112,26 @@ struct MoeGemmPipelineAgBgCrImpl
|
||||
static_assert(kKPerBlock == ADramBlockWindow{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
|
||||
constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
|
||||
constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
@@ -268,106 +154,171 @@ struct MoeGemmPipelineAgBgCrImpl
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
|
||||
|
||||
auto b_gate_flat_dram_window =
|
||||
auto b_gate_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
b_flat_dram_block_window_tmp.move({N, 0})
|
||||
auto b_up_flat_dram_window =
|
||||
move_tile(b_flat_dram_block_window_tmp, {N, 0});
|
||||
auto b_up_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
// Acc register tile
|
||||
using c_block_tile_type = decltype(block_flatmm(a_lds_gemm_window, b_gate_flat_dram_window));
|
||||
auto c_block_tiles[2] = {c_block_tile_type{}, c_block_tile_type{}};
|
||||
auto c_gate_block_tile = c_block_tile_type{};
|
||||
auto c_up_block_tile = c_block_tile_type{}
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = a_dram_block_window.load();
|
||||
a_block_tile = load_tile(a_dram_block_window);
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_gate_flat_dram_window), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_gate_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(load_tile(b_up_flat_dram_window)), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_2;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_gate_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(a_dram_block_window, {0, kKPerBlock});
|
||||
|
||||
// move to next flat K
|
||||
move_tile_window(b_gate_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tiles[0]);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tiles[1]);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
|
||||
shuffle_tile(a_shuffle_tmp, a_block_tile);
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
}
|
||||
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
a_dram_block_window.load(a_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
a_block_tile = load_tile(a_dram_block_window);
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tiles[0], a_lds_gemm_window, b_gate_flat_dram_window);
|
||||
|
||||
//TODO: simply add b_gate flatmm
|
||||
block_flatmm(c_block_tiles[1], a_lds_gemm_window, b_up_flat_dram_window);
|
||||
block_flatmm(c_gate_block_tile, a_warp_windows, b_warp_tensor);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_dram_block_window, {0, kKPerBlock});
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_up_flat_dram_window;
|
||||
|
||||
// LDS write i + 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// move to next flat K
|
||||
move_tile_window(b_up_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_up_block_tile, a_warp_windows, b_warp_tensor_2);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_gate_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// move to next flat K
|
||||
move_tile_window(b_gate_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(b_up_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
|
||||
// HotLoopScheduler();
|
||||
block_sync_lds();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
// GEMM i
|
||||
block_flatmm(c_gate_block_tile, a_warp_windows, b_warp_tensor);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_up_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
|
||||
// HotLoopScheduler();
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
block_flatmm(c_block_tiles[0], a_lds_gemm_window, b_gate_flat_dram_window);
|
||||
block_flatmm(c_block_tiles[1], a_lds_gemm_window, b_up_flat_dram_window);
|
||||
block_flatmm(c_up_block_tile, a_warp_windows, b_warp_tensor_2);
|
||||
}
|
||||
|
||||
sweep_tile(c_block_tiles[0],
|
||||
sweep_tile(c_gate_block_tile,
|
||||
[&](auto idx0, auto idx1) {
|
||||
fp32x2_t v_{c_block_tiles[0].at(number<0>{})(idx0), c_block_tiles[0].at(number<0>{})(idx1)};
|
||||
fp32x2_t v_{c_gate_block_tile.at(number<0>{})(idx0), c_gate_block_tile.at(number<0>{})(idx1)};
|
||||
typename Problem::GateActivation{}(v_, v_);
|
||||
c_block_tiles[0].at(number<0>{})(idx0) = v_.x;
|
||||
c_block_tiles[0].at(number<0>{})(idx1) = v_.y;
|
||||
c_gate_block_tile.at(number<0>{})(idx0) = v_.x;
|
||||
c_gate_block_tile.at(number<0>{})(idx1) = v_.y;
|
||||
},
|
||||
sequence<1, 2>{});
|
||||
|
||||
auto c_block_tile =
|
||||
tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; },
|
||||
c_block_tiles[0],
|
||||
c_block_tiles[1]);
|
||||
c_gate_block_tile,
|
||||
c_up_block_tile);
|
||||
|
||||
return c_block_tiles[0];
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp>
|
||||
|
||||
Reference in New Issue
Block a user