diff --git a/example/ck_tile/18_flatmm_uk/CMakeLists.txt b/example/ck_tile/18_flatmm_uk/CMakeLists.txt new file mode 100644 index 0000000000..f0fe2dc6d0 --- /dev/null +++ b/example/ck_tile/18_flatmm_uk/CMakeLists.txt @@ -0,0 +1,19 @@ +set(TILE_EXAPMLE_FLATMM_UK "tile_example_flatmm_uk") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding ${TILE_EXAPMLE_FLATMM_UK}") +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_executable(${TILE_EXAPMLE_FLATMM_UK} EXCLUDE_FROM_ALL main.cpp) +target_include_directories(${TILE_EXAPMLE_FLATMM_UK} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${TILE_EXAPMLE_FLATMM_UK} PRIVATE ${INSTANCE_SRCS}) + +set(TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +list(APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a +list(APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta +# list(APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1) +# list(APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + +target_compile_options(${TILE_EXAPMLE_FLATMM_UK} PRIVATE ${TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS}) diff --git a/example/ck_tile/18_flatmm_uk/flatmm_uk.hpp b/example/ck_tile/18_flatmm_uk/flatmm_uk.hpp new file mode 100644 index 0000000000..57c1d17eb7 --- /dev/null +++ b/example/ck_tile/18_flatmm_uk/flatmm_uk.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/flatmm_uk.hpp" +#include + +// this is only a convenient structure for creating an example +// this is not part of the host API +template +struct FlatmmUkTypeConfig; + +template +struct FlatmmUkTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using GDataType = ck_tile::bf16_t; + using DDataType = ck_tile::bf16_t; + using AccDataType = float; + using ODataType = ck_tile::bf16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::index_t; +}; + +template +struct FlatmmUkTypeConfig +{ + using ADataType = ck_tile::fp16_t; + using GDataType = ck_tile::fp16_t; + using DDataType = ck_tile::fp16_t; + using AccDataType = float; + using ODataType = ck_tile::fp16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::index_t; +}; + +template +struct FlatmmUkTypeConfig +{ + using ADataType = ck_tile::int8_t; + using GDataType = ck_tile::int8_t; + using DDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using ODataType = ck_tile::bf16_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::index_t; +}; + + +struct flatmm_uk_args +{ + const void* a_ptr; // [m, k], input token + const void* b_ptr; // [m, k], input token + const void* c_ptr; // [m, k], output token (no need to do zeroing) + void* d_ptr; // [m, k], output token (no need to do zeroing) + void* dbg_int_ptr; // [m, k], output token (no need to do zeroing) + void* dbg_bf16_ptr; // [m, k], output token (no need to do zeroing) + void* dbg_fp32_ptr; // [m, k], output token (no need to do zeroing) + + ck_tile::index_t block_m; // block_m, used to devide the input + ck_tile::index_t hidden_size; // k + ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2 + ck_tile::index_t num_tokens; // input number of tokens for current iteration + ck_tile::index_t num_experts; // number of groups + ck_tile::index_t topk; // need this? + + ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size +}; + +// This is the public API, will be generated by script +struct flatmm_uk_traits +{ + std::string prec_i; // input precision + std::string prec_w; // weight precision + std::string prec_o; // output precision + std::string prec_st; // token scale data type + std::string prec_sw; // weight scale data type + std::string prec_sq; // smooth quant scale + std::string prec_kw; // topk-weight data type + int block_m; + int gate_only; + int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant +}; + +float flatmm_uk(flatmm_uk_traits, flatmm_uk_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.cpp b/example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.cpp new file mode 100644 index 0000000000..a32f82ac9f --- /dev/null +++ b/example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.cpp @@ -0,0 +1,192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "flatmm_uk.hpp" +#include "flatmm_uk_api.hpp" +#include "ck_tile/ops/flatmm_uk.hpp" +#include + +template +using S = ck_tile::sequence; + +// do not the define of this tepmlate function inside the _api.cpp, otherwise will block make -j +template +float flatmm_uk_(const ck_tile::stream_config& s_, flatmm_uk_args_ a_) +{ + printf("[FF] ======= fused_moegemm_() ======= \n \tget moe arg in a_ , get " + "config in Ts_\n"); + using f_traits = ck_tile::FusedMoeGemmTraits; + using f_shape = ck_tile::FusedMoeGemmShape; + printf("[FF] --- fused_moegemm_(): --- \n"); + printf("[FF] f_shape::BlockSize = %d\n", static_cast(f_shape::BlockSize)); + printf("[FF] f_shape::NumWarps = %d\n", static_cast(f_shape::NumWarps)); + printf("[FF] --------- \n"); + printf("[FF] f_shape::Block_M0 = %d\n", static_cast(f_shape::Block_M0)); + printf("[FF] f_shape::Block_N0 = %d\n", static_cast(f_shape::Block_N0)); + printf("[FF] f_shape::Block_K0 = %d\n", static_cast(f_shape::Block_K0)); + printf("[FF] f_shape::WarpPerBlock_M0 = %d\n", static_cast(f_shape::WarpPerBlock_M0)); + printf("[FF] f_shape::WarpPerBlock_N0 = %d\n", static_cast(f_shape::WarpPerBlock_N0)); + printf("[FF] f_shape::WarpPerBlock_K0 = %d\n", static_cast(f_shape::WarpPerBlock_K0)); + printf("[FF] f_shape::Warp_M0 = %d\n", static_cast(f_shape::Warp_M0)); + printf("[FF] f_shape::Warp_N0 = %d\n", static_cast(f_shape::Warp_N0)); + printf("[FF] f_shape::Warp_K0 = %d\n", static_cast(f_shape::Warp_K0)); + printf("[FF] f_shape::ThreadPerBlock_M0 = %d\n", + static_cast(f_shape::ThreadPerBlock_M0)); + printf("[FF] f_shape::ThreadPerBlock_N0 = %d\n", + static_cast(f_shape::ThreadPerBlock_N0)); + printf("[FF] f_shape::ThreadPerBlock_K0 = %d\n", + static_cast(f_shape::ThreadPerBlock_K0)); + printf("[FF] f_shape::Repeat_M0 = %d\n", static_cast(f_shape::Repeat_M0)); + printf("[FF] f_shape::Repeat_N0 = %d\n", static_cast(f_shape::Repeat_N0)); + printf("[FF] f_shape::Repeat_K0 = %d\n", static_cast(f_shape::Repeat_K0)); + printf("[FF] f_shape::Block_W0 = %d\n", static_cast(f_shape::Block_W0)); + printf("[FF] f_shape::Block_Nr0 = %d\n", static_cast(f_shape::Block_Nr0)); + printf("[FF] f_shape::Block_Kr0 = %d\n", static_cast(f_shape::Block_Kr0)); + printf("[FF] --------- \n"); + printf("[FF] f_shape::Block_M1 = %d\n", static_cast(f_shape::Block_M1)); + printf("[FF] f_shape::Block_N1 = %d\n", static_cast(f_shape::Block_N1)); + printf("[FF] f_shape::Block_K1 = %d\n", static_cast(f_shape::Block_K1)); + printf("[FF] f_shape::WarpPerBlock_M1 = %d\n", static_cast(f_shape::WarpPerBlock_M1)); + printf("[FF] f_shape::WarpPerBlock_N1 = %d\n", static_cast(f_shape::WarpPerBlock_N1)); + printf("[FF] f_shape::WarpPerBlock_K1 = %d\n", static_cast(f_shape::WarpPerBlock_K1)); + printf("[FF] f_shape::Warp_M1 = %d\n", static_cast(f_shape::Warp_M1)); + printf("[FF] f_shape::Warp_N1 = %d\n", static_cast(f_shape::Warp_N1)); + printf("[FF] f_shape::Warp_K1 = %d\n", static_cast(f_shape::Warp_K1)); + printf("[FF] f_shape::ThreadPerBlock_M1 = %d\n", + static_cast(f_shape::ThreadPerBlock_M1)); + printf("[FF] f_shape::ThreadPerBlock_N1 = %d\n", + static_cast(f_shape::ThreadPerBlock_N1)); + printf("[FF] f_shape::ThreadPerBlock_K1 = %d\n", + static_cast(f_shape::ThreadPerBlock_K1)); + printf("[FF] f_shape::Repeat_M1 = %d\n", static_cast(f_shape::Repeat_M1)); + printf("[FF] f_shape::Repeat_N1 = %d\n", static_cast(f_shape::Repeat_N1)); + printf("[FF] f_shape::Repeat_K1 = %d\n", static_cast(f_shape::Repeat_K1)); + printf("[FF] f_shape::Block_W1 = %d\n", static_cast(f_shape::Block_W1)); + printf("[FF] f_shape::Block_Nr1 = %d\n", static_cast(f_shape::Block_Nr1)); + printf("[FF] f_shape::Block_Kr1 = %d\n", static_cast(f_shape::Block_Kr1)); + using f_problem = + ck_tile::FusedMoeGemmPipelineProblem; + + // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx; + using f_pipeline = ck_tile::GemmPipeline_FlatmmUk; + using f_kernel = ck_tile::FlatmmUkKernel; + + const dim3 grids = f_kernel::GridSize(a_); + constexpr dim3 blocks = f_kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + printf("[FF] grids = [%d, %d, %d]\n", grids.x, grids.y, grids.z); + printf("[FF] blocks = [%d, %d, %d]\n", blocks.x, blocks.y, blocks.z); + + static int printed = 0; + + auto kargs = f_kernel::MakeKargs(a_); + f_kernel kernel{}; + auto lambda_kenrel = + ck_tile::make_kernel(kernel, grids, blocks, 0, kargs); + + if(s_.log_level_ > 0 && printed == 10) + { + // std::cout << ", " << f_kernel::GetName() << std::flush; + printed = 1; + } + + return ck_tile::launch_kernel( + s_, lambda_kenrel + // ck_tile::make_kernel(f_kernel{}, grids, blocks, 0, kargs) + ); +} + +float flatmm_uk(flatmm_uk_traits t, flatmm_uk_args a, const ck_tile::stream_config& s) +{ + // auto s_ = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1}; + auto s_ = s; + + auto t_ = flatmm_uk_traits_{t.prec_i, + t.prec_w, + t.prec_o, + t.prec_st, + t.prec_sw, + t.prec_sq, + t.prec_kw, + t.block_m, + t.gate_only, + t.fused_quant}; + auto a_ = flatmm_uk_args_{ + a.a_ptr, // const void* a_ptr; + a.b_ptr, // const void* a_ptr; + a.c_ptr, // void* o_ptr; + a.d_ptr, // void* o_ptr; + a.dbg_int_ptr, + a.dbg_bf16_ptr, + a.dbg_fp32_ptr, + a.hidden_size, // index_t hidden_size; + a.intermediate_size, // index_t intermediate_size; + a.num_tokens, // index_t num_tokens; + a.num_experts, // index_t num_experts; + a.topk, // index_t topk; + a.stride_token // index_t stride_token; + }; + + float r = -1; + + if(t_.prec_i == "bf16" && t_.prec_w == "bf16" && t_.prec_o == "bf16" && t_.prec_st == "fp32" && + t_.prec_sw == "fp32" && t_.prec_sq == "fp32" && t_.prec_kw == "fp32" && t_.block_m == 32 && + t_.gate_only == 1) + { + using t_ = fmoe_, + S<1, 4, 1>, + S<16, 16, 32>, + 1, + 0>; + r = flatmm_uk_(s_, a_); + } + else if(t_.prec_i == "fp16" && t_.prec_w == "fp16" && t_.prec_o == "fp16" && + t_.prec_st == "fp32" && t_.prec_sw == "fp32" && t_.prec_sq == "fp32" && + t_.prec_kw == "fp32" && t_.block_m == 32 && t_.gate_only == 1) + { + using t_ = fmoe_, + S<1, 4, 1>, + S<16, 16, 32>, + 1, + 0>; + r = flatmm_uk_(s_, a_); + } + + // keep unsupported case return negative + if(r < 0) + return -1; + + return r; +} diff --git a/example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.hpp b/example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.hpp new file mode 100644 index 0000000000..c8850b574c --- /dev/null +++ b/example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.hpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/flatmm_uk.hpp" +#include + +// runtime args +struct flatmm_uk_args_ : public ck_tile::FlatmmUkHostArgs +{ +}; + +// This is the public API, will be generated by script +struct flatmm_uk_traits_ +{ + std::string prec_i; // input precision + std::string prec_w; // weight precision + std::string prec_o; // output precision + std::string prec_st; // token scale data type + std::string prec_sw; // weight scale data type + std::string prec_sq; // smooth quant scale + std::string prec_kw; // topk-weight data type + int block_m; + int gate_only; + int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant +}; + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template + typename WarpPerBlock_, + typename WarpTile_, // seq<*,*,*>, used to select mfma + ck_tile::index_t GateOnly_ = 0, + ck_tile::index_t FusedQuant_ = 0> +struct fmoe_ // traits, ugly name, only used for internal +{ + using TypeConfig = FlatmmUkTypeConfig; + + using ADataType = ck_tile::remove_cvref_t; + using GDataType = ck_tile::remove_cvref_t; + using DDataType = ck_tile::remove_cvref_t; + using AccDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + using AScaleDataType = ck_tile::remove_cvref_t; + using GScaleDataType = ck_tile::remove_cvref_t; + using DScaleDataType = ck_tile::remove_cvref_t; + using YSmoothScaleDataType = ck_tile::remove_cvref_t; + using TopkWeightDataType = ck_tile::remove_cvref_t; + using IndexDataType = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token + static constexpr ck_tile::index_t BI_ = + BlockTIle_::at(ck_tile::number<1>{}); // block intermediate + static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden + static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down + + using BlockTile_0 = ck_tile::sequence; + using WarpPerBlock_0 = ck_tile::remove_cvref_t; + using WarpTile_0 = ck_tile::remove_cvref_t; + + using BlockTile_1 = ck_tile::sequence; + using WarpPerBlock_1 = ck_tile::remove_cvref_t; + using WarpTile_1 = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t GateOnly = GateOnly_; + static constexpr ck_tile::index_t FusedQuant = FusedQuant_; +}; diff --git a/example/ck_tile/18_flatmm_uk/main.cpp b/example/ck_tile/18_flatmm_uk/main.cpp new file mode 100644 index 0000000000..a4b2ca6911 --- /dev/null +++ b/example/ck_tile/18_flatmm_uk/main.cpp @@ -0,0 +1,692 @@ +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "flatmm_uk.hpp" + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template +CK_TILE_HOST void my_reference_gemm(const ck_tile::HostTensor& a_m_k, + const ck_tile::HostTensor& b_k_n, + ck_tile::HostTensor& c_m_n, + float t, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(0); + const std::size_t K = a_m_k.get_length(1); + printf("[REF] M = %zu, N = %zu, K = %zu\n", M, N, K); + + auto cal_tflops = [&](auto ms) { + double flop_gemm = 2.0 * M * N * K; + return (flop_gemm) / (static_cast(ms) * 1e-3) / 1e12; + }; + + auto cal_tbps = [&](auto ms) { + double a_bytes = static_cast(M) * K * sizeof(ADataType); + double b_bytes = static_cast(N) * K * sizeof(BDataType); + double o_bytes = static_cast(M) * N * sizeof(CDataType); + + return (a_bytes + b_bytes + o_bytes) / (static_cast(ms) * 1e-3) / 1e12; + }; + + std::cout << ", " << t * 1.E3 << " us, " << cal_tflops(t) << " tflops, " << cal_tbps(t) + << " TB/s" << std::endl + << std::flush; + + auto f_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + + for(std::size_t k = 0; k < K; ++k) + { + ADataType v_a = a_element_op(a_m_k(m, k)); + BDataType v_b = b_element_op(b_k_n(n, k)); + + v_acc += + ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + } + + c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + }; + + ck_tile::make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); +} + +// mfma_type, 0:32x32, 1:16x16 +// TODO: padding? +template +auto shuffle_moe_weight(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type = 0) +{ + assert(t.get_lengths().size() == 3); + int b_ = t.get_lengths()[0]; + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[2]; + if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) + { + ck_tile::HostTensor t_view({b_, n_ / 32, 32, k_ / 16, 2, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) + { + ck_tile::HostTensor t_view({b_, n_ / 16, 16, k_ / 32, 4, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0) + { + ck_tile::HostTensor t_view({b_, n_ / 32, 32, k_ / 32, 2, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1) + { + ck_tile::HostTensor t_view({b_, n_ / 16, 16, k_ / 64, 4, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5}); + } + return t; +} +template +auto shuffle_weight(const ck_tile::HostTensor& t, std::string mfma_dtype, int mfma_type = 0) +{ + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[0]; + int k_ = t.get_lengths()[1]; + if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) + { + ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 16, 2, 8}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1) + { + ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 32, 4, 8}); + printf("[FF] permute: n_ = %d, k_ = %d, n_/16 = %d, k_/32 = %d\n", n_, k_, n_ / 16, k_ / 32); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0) + { + ck_tile::HostTensor t_view({n_ / 32, 32, k_ / 32, 2, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1) + { + ck_tile::HostTensor t_view({n_ / 16, 16, k_ / 64, 4, 16}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + return t; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "64", "num of m") + .insert("n", "1024", "num of n") + .insert("k", "8192", "num of k") + .insert("t", "64", "num input tokens") + .insert("e", "8", "num of experts") + .insert("tk", "1", "topk") + .insert("h", "4096", "hidden_size of this model") + .insert("i", "4096", "intermediate_size between 2 gemms of FFN") + .insert("stride", "-1", "stride per row, if -1 then equal to hidden_size") + .insert("bm", "32", "blocking factor for sorted tokens") + .insert("tp", "8", "tensor parallel size") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec_i", "bf16", "input precision") + .insert("prec_w", "bf16", "weight precision") + .insert("prec_o", "bf16", "output precision") + .insert("prec_st", "auto", "token scale data type. auto will set to fp32") + .insert("prec_sw", "auto", "weight scale data type. auto will set to fp32") + .insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32") + .insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32") + .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") + .insert( + "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") + .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm") + .insert("balance", + "0", + "if set to 1, will try balance the expert in topk-ids(convenient for testing)") + .insert("init", + "2", + "init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized" + "normalized(slow)") + .insert("seed", "11939", "seed used to do random") + .insert("warmup", "1", "cold iter") + .insert("repeat", "4", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type, +// SQ:smooth-quant-type, KW:topk-weight-type +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + printf("[FF] M = %d, N = %d, K = %d\n", M, N, K); + + ck_tile::index_t experts = arg_parser.get_int("e"); + ck_tile::index_t topk = arg_parser.get_int("tk"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + ck_tile::index_t block_m = arg_parser.get_int("bm"); + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_w = arg_parser.get_str("prec_w"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_st = arg_parser.get_str("prec_st"); + std::string prec_sw = arg_parser.get_str("prec_sw"); + std::string prec_sq = arg_parser.get_str("prec_sq"); + std::string prec_kw = arg_parser.get_str("prec_kw"); + prec_st = (prec_st == "auto") ? "fp32" : prec_st; + prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; + prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; + prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int fused_quant = arg_parser.get_int("fquant"); + int gate_only = arg_parser.get_int("gate_only"); + int init = arg_parser.get_int("init"); + uint32_t seed = arg_parser.get_uint32("seed"); + + using TypeConfig = FlatmmUkTypeConfig; + using ADataType = typename TypeConfig::ADataType; + using BDataType = ADataType; + using AccDataType = typename TypeConfig::AccDataType; + using CDataType = AccDataType; + using DDataType = AccDataType; + + // host verify + ck_tile::HostTensor a_host({M, K}); + ck_tile::HostTensor b_host({N, K}); + ck_tile::HostTensor c_host({M, N}); + ck_tile::HostTensor d_host({M, N}); + + ck_tile::HostTensor dbg_int({M * N, K}); + ck_tile::HostTensor dbg_fp32({M * N, K}); + ck_tile::HostTensor dbg_bf16({M * N, K}); + + if(init == 0) + { + ck_tile::FillStepRange{-.5f, .5f, 0.01f}(a_host); + ck_tile::FillStepRange{-.5f, .5f, 0.01f}(b_host); + } + else if(init == 1) + { + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f, seed, true}(b_host); + } + else if(init == 2) + { + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(a_host); + ck_tile::FillNormalDistribution{0.f, 1.f, seed, true}(b_host); + } + /* + // a_host + { + int X = static_cast(K); + int Y = static_cast(M); + + for(int y = 0; y < Y; y++) + { + for(int x = 0; x < X; x++) + { + int idx = X * y + x; + a_host.mData[idx] = ck_tile::type_convert(x * 1.0f); + //b_host.mData[idx] = ck_tile::type_convert(y * 1.0f); + //b_host.mData[idx] = ck_tile::type_convert(y*1.f + x * 0.0001f); + } + } + } + // b_host + { + int X = static_cast(K); + int Y = static_cast(N); + + for(int y = 0; y < Y; y++) + { + for(int x = 0; x < X; x++) + { + int idx = X * y + x; + b_host.mData[idx] = ck_tile::type_convert(idx * 1.0f); + //b_host.mData[idx] = ck_tile::type_convert(y * 1.0f); + //b_host.mData[idx] = ck_tile::type_convert(y*1.f + x * 0.0001f); + } + } + }*/ + + // permute weight + ck_tile::HostTensor b_perm_host = shuffle_weight(b_host, prec_w, 1); + + ck_tile::DeviceMem a_buf(a_host); + ck_tile::DeviceMem b_buf(b_perm_host); // b_host -> b_perm_host + ck_tile::DeviceMem c_buf(c_host); + ck_tile::DeviceMem d_buf(d_host); + ck_tile::DeviceMem dbg_int_buf(dbg_int); + ck_tile::DeviceMem dbg_bf16_buf(dbg_bf16); + ck_tile::DeviceMem dbg_fp32_buf(dbg_fp32); + + flatmm_uk_traits traits{prec_i, + prec_w, + prec_o, + prec_st, + prec_sw, + prec_sq, + prec_kw, + block_m, + gate_only, + fused_quant}; + printf("[FF] --- run(): ---\n"); + printf("[FF] traits.prec_i = %s\n", traits.prec_i.c_str()); + printf("[FF] traits.prec_w = %s\n", traits.prec_w.c_str()); + printf("[FF] traits.prec_o = %s\n", traits.prec_o.c_str()); + printf("[FF] traits.prec_st = %s\n", traits.prec_st.c_str()); + printf("[FF] traits.prec_sw = %s\n", traits.prec_sw.c_str()); + printf("[FF] traits.prec_sq = %s\n", traits.prec_sq.c_str()); + printf("[FF] traits.prec_kw = %s\n", traits.prec_kw.c_str()); + printf("[FF] traits.block_m = %d\n", traits.block_m); + printf("[FF] traits.gate_only = %d\n", traits.gate_only); + printf("[FF] traits.fused_quant = %d\n", traits.fused_quant); + + flatmm_uk_args args{a_buf.GetDeviceBuffer(), + b_buf.GetDeviceBuffer(), + c_buf.GetDeviceBuffer(), + d_buf.GetDeviceBuffer(), + dbg_int_buf.GetDeviceBuffer(), + dbg_bf16_buf.GetDeviceBuffer(), + dbg_fp32_buf.GetDeviceBuffer(), + block_m, + K, + N, + M, + experts, + topk, + stride}; + printf("[FF] --- run(): ---\n"); + printf("[FF] args.block_m = %d\n", args.block_m); + printf("[FF] args.hidden_size = %d\n", args.hidden_size); + printf("[FF] args.intermediate_size = %d\n", args.intermediate_size); + printf("[FF] args.num_tokens = %d\n", args.num_tokens); // 1 + printf("[FF] args.topk = %d\n", args.topk); // 0 + printf("[FF] args.num_experts = %d\n", args.num_experts); // 0 + printf("[FF] args.stride_token = %d\n", args.stride_token); + + float ave_time = flatmm_uk( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + if(ave_time < 0) + { + std::cout << " not supported!" << std::endl << std::flush; + return false; + } + + bool pass = true; + + if(do_validation) + { + auto d_dev = d_buf.ToHost(); + std::cout << std::endl << " =================== " << std::endl; + d_host.SetZero(); + my_reference_gemm( + a_host, b_host, d_host, ave_time); + pass = ck_tile::check_err(d_dev, d_host); + std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl; + } + +#if 0 + int GridDimX = 2; + int GridDimY = 1; + int BlockDimX = 64; + int BlockDimY = 4; + int BlockSize = BlockDimX * BlockDimY; + // dbg_int + { + auto dbg_int_dev = dbg_int_buf.ToHost(); + std::ofstream file("ff_dbg_int.txt"); + file << " [dbg_int]: Grid = [" << GridDimX << ", " << GridDimY << "], Block = " << BlockSize + << std::endl; + + for(int bidy = 0; bidy < GridDimY; bidy++) + { + for(int bidx = 0; bidx < GridDimX; bidx++) + { + file << "\n ========== block : [" << bidx << ", " << bidy << "] =========="; + for(int tid = 0; tid < BlockSize; tid++) + { + int gid = (BlockSize * GridDimX) * bidy + BlockSize * bidx + tid; + if(tid % 64 == 0) + { + file << "\n [" << tid << " : " << tid + 63 << "]: "; + } + file << ck_tile::type_convert(dbg_int_dev.mData[gid]) << ", "; + } + } + } + + file.close(); + } + // dbg_bf16 ---> kernel + { + auto dbg_bf16_dev = dbg_bf16_buf.ToHost(); + std::ofstream file("ff_dbg_bf16_kernel.txt"); + file << " [dbg_bf16]: Grid = [" << GridDimX << ", " << GridDimY + << "], Block = " << BlockSize << std::endl; + + for(int bidy = 0; bidy < GridDimY; bidy++) + { + for(int bidx = 0; bidx < GridDimX; bidx++) + { + file << "\n ========== block : [" << bidx << ", " << bidy << "] =========="; + for(int tid = 0; tid < BlockSize; tid++) + { + int gid = (BlockSize * bidx) * bidy + BlockSize * bidx + tid; + + file << "\n [" << tid << "]: "; + for(int i = 0; i < 64; i++) // multi output per thread + file << ck_tile::type_convert(dbg_bf16_dev.mData[gid * 64 + i]) + << ", "; + } + } + } + + file.close(); + } + // dbg_bf16 + { + auto dbg_bf16_dev = dbg_bf16_buf.ToHost(); + std::ofstream file("ff_dbg_bf16.txt"); + int X = static_cast(N); + int Y = static_cast(M); + file << " [dbg_bf16]: Row = " << Y << ", Col = " << X << std::endl; + + for(int m = 0; m < Y; m++) + { + file << "\n ========== row : [" << m << " / " << Y << "] =========="; + for(int n = 0; n < X; n++) + { + if(n % 64 == 0) + { + file << "\n [" << n << " : " << n + 63 << "]: "; + } + int idx = X * m + n; + file << ck_tile::type_convert(dbg_bf16_dev.mData[idx]) << ", "; + } + } + + file.close(); + } + // dbg_fp32 ---> kernel + { + auto dbg_fp32_dev = dbg_fp32_buf.ToHost(); + std::ofstream file("ff_dbg_fp32_kernel.txt"); + file << " [dbg_fp32]: Grid = [" << GridDimX << ", " << GridDimY + << "], Block = " << BlockSize << std::endl; + + for(int bidy = 0; bidy < GridDimY; bidy++) + { + for(int bidx = 0; bidx < GridDimX; bidx++) + { + file << "\n ========== block : [" << bidx << ", " << bidy << "] =========="; + for(int tid = 0; tid < BlockSize; tid++) + { + int gid = (BlockSize * bidx) * bidy + BlockSize * bidx + tid; + + file << "\n [" << tid << "]: "; + for(int i = 0; i < 64; i++) // multi output per thread + file << ck_tile::type_convert(dbg_fp32_dev.mData[gid * 64 + i]) + << ", "; + + // if(tid % 64 == 0) // one output per thread + // file << "\n [" << tid << " : " << tid + 63 << "]: "; + // file << ck_tile::type_convert(dbg_bf16.mData[gid]) << ", "; + } + } + } + + file.close(); + } + // dbg_fp32 + { + auto dbg_fp32_dev = dbg_fp32_buf.ToHost(); + std::ofstream file("ff_dbg_fp32.txt"); + int X = static_cast(N); + int Y = static_cast(M); + file << " [dbg_fp32]: Row = " << Y << ", Col = " << X << std::endl; + + for(int m = 0; m < Y; m++) + { + file << "\n ========== row : [" << m << " / " << Y << "] =========="; + for(int n = 0; n < X; n++) + { + if(n % 64 == 0) + { + file << "\n [" << n << " : " << n + 63 << "]: "; + } + int idx = X * m + n; + file << ck_tile::type_convert(dbg_fp32_dev.mData[idx]) << ", "; + } + } + + file.close(); + } + // a_host + { + std::ofstream file("ff_a_host.txt"); + int X = static_cast(K); + int Y = static_cast(M); + file << " [a_host]: Row = " << Y << ", Col = " << X << std::endl; + + for(int y = 0; y < Y; y++) + { + file << "\n ========== row : [" << y << " / " << Y << "] =========="; + for(int x = 0; x < X; x++) + { + int idx = X * y + x; + if(idx % 16 == 0) + { + file << "\n [" << x << " : " << x + 15 << " ]: "; + } + + file << ck_tile::type_convert(a_host.mData[idx]) << ", "; + } + } + + file.close(); + } + // b_host + { + std::ofstream file("ff_b_host.txt"); + int X = static_cast(K); + int Y = static_cast(N); + file << " [b_host]: Row = " << Y << ", Col = " << X << std::endl; + + for(int y = 0; y < Y; y++) + { + file << "\n ========== row : [" << y << " / " << Y << "] =========="; + for(int x = 0; x < X; x++) + { + int idx = X * y + x; + if(idx % 16 == 0) + { + file << "\n [" << x << " : " << x + 15 << " ]: "; + } + + file << ck_tile::type_convert(b_host.mData[idx]) << ", "; + } + } + + file.close(); + } + // permute_b + { + std::ofstream file("ff_b_perm_host.txt"); + int X = static_cast(K); + int Y = static_cast(N); + file << " [b_perm_host]: Row = " << Y << ", Col = " << X << std::endl; + + for(int y = 0; y < Y; y++) + { + file << "\n ========== row : [" << y << " / " << Y << "] =========="; + for(int x = 0; x < X; x++) + { + int idx = X * y + x; + if(idx % 16 == 0) + { + file << "\n [" << x << " : " << x + 15 << " ]: "; + } + + file << ck_tile::type_convert(b_perm_host.mData[idx]) << ", "; + } + } + + file.close(); + } + // d_dev ---> kernel + { + auto d_dev = d_buf.ToHost(); + std::ofstream file("ff_d_dev_kernel.txt"); + file << " [d_dev]: Grid = [" << GridDimX << ", " << GridDimY << "], Block = " << BlockSize + << std::endl; + + for(int bidy = 0; bidy < GridDimY; bidy++) + { + for(int bidx = 0; bidx < GridDimX; bidx++) + { + file << "\n ========== block : [" << bidx << ", " << bidy << "] =========="; + for(int tid = 0; tid < BlockSize; tid++) + { + int gid = (BlockSize * bidx) * bidy + BlockSize * bidx + tid; + + file << "\n [" << tid << "]: "; + for(int i = 0; i < 64; i++) // multi output per thread + file << ck_tile::type_convert(d_dev.mData[gid * 64 + i]) << ", "; + } + } + } + + file.close(); + } + // d_dev + { + auto d_dev = d_buf.ToHost(); + std::ofstream file("ff_d_dev.txt"); + int X = static_cast(N); + int Y = static_cast(M); + file << " [d_dev]: Row = " << Y << ", Col = " << X << std::endl; + + for(int y = 0; y < Y; y++) + { + file << "\n ========== row : [" << y << " / " << Y << "] =========="; + for(int x = 0; x < X; x++) + { + if(x % 64 == 0) + { + file << "\n [" << x << " : " << x + 63 << "]: "; + } + int idx = X * y + x; + file << ck_tile::type_convert(d_dev.mData[idx]) << ", "; + } + } + + file.close(); + } + // d_host + { + std::ofstream file("ff_d_host.txt"); + int X = static_cast(N); + int Y = static_cast(M); + file << " [d_host]: Row = " << Y << ", Col = " << X << std::endl; + + for(int y = 0; y < Y; y++) + { + file << "\n ========== row : [" << y << " / " << Y << "] =========="; + for(int x = 0; x < X; x++) + { + if(x % 64 == 0) + { + file << "\n [" << x << " : " << x + 63 << "]: "; + } + int idx = X * y + x; + file << ck_tile::type_convert(d_host.mData[idx]) << ", "; + } + } + + file.close(); + } +#endif + + std::cout << std::flush << std::endl; + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string prec_i = arg_parser.get_str("prec_i"); + std::string prec_w = arg_parser.get_str("prec_w"); + std::string prec_o = arg_parser.get_str("prec_o"); + std::string prec_st = arg_parser.get_str("prec_st"); + std::string prec_sw = arg_parser.get_str("prec_sw"); + std::string prec_sq = arg_parser.get_str("prec_sq"); + std::string prec_kw = arg_parser.get_str("prec_kw"); + prec_st = (prec_st == "auto") ? "fp32" : prec_st; + prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw; + prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq; + prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw; + + // no dynamic quant case + if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32") + { + return run( + arg_parser) + ? 0 + : -2; + } + else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32") + { + return run( + arg_parser) + ? 0 + : -2; + } + + return -3; +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 296eb1ecef..897e77115c 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -17,3 +17,4 @@ add_subdirectory(14_moe_smoothquant) add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) +add_subdirectory(18_flatmm_uk) diff --git a/include/ck_tile/ops/flatmm/block/flatmm_ff_32x512x128_1x4x1_16x16x32.hpp b/include/ck_tile/ops/flatmm/block/flatmm_ff_32x512x128_1x4x1_16x16x32.hpp new file mode 100644 index 0000000000..3bd9dff11a --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/flatmm_ff_32x512x128_1x4x1_16x16x32.hpp @@ -0,0 +1,665 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" + +namespace ck_tile { + +// A async load to LDS, B direct to AGPR +// B matrix preshuffled in br*kr*w +// require 4 wave, occupancy=1c +// agpr useage:256 +// vgpr usage:64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112) +// +// for this gemm, 4 16x16x16 transposed layout +// input A vpgpr layout +// v0-v15: [ 0:15](gemm_m)x128(gemm_k) +// v16-v31: [16:31](gemm_m)x128(gemm_k) + +// input B vpgpr layout +// v0-v15: [ 0: 15](gemm_n)x128(gemm_k) +// v16-v31: [ 64: 79](gemm_n)x128(gemm_k) +// ...................... +// v111-v127: [448:463](gemm_n)x128(gemm_k) + +// output C vpgpr layout +// v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n) +// v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n) +// v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n) +// v12-v15: [16:31](gemm_m)x[64: 79](gemm_n) +// ...................... +// v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n) +// v60-v63: [16:31](gemm_m)x[448:463](gemm_n) +struct Flatmm_ff_32x512x128_1x4x1_16x16x32_Base // for f16/bf16 +{ + static constexpr index_t Block_M = 32; + static constexpr index_t Block_N = 512; + static constexpr index_t Block_K = 128; + + static constexpr index_t WarpPerBlock_M = 1; + static constexpr index_t WarpPerBlock_N = 4; + static constexpr index_t WarpPerBlock_K = 1; + + static constexpr index_t NumWarps = 4; + + static constexpr index_t Warp_M = 16; + static constexpr index_t Warp_N = 16; + static constexpr index_t Warp_K = 32; // 16 * SubKPacks + + static constexpr index_t BlockSize = 256; + + static constexpr index_t SubKPacks = 2; // this is used to gurantee every threads can do dwordx4 + + // TODO: note Nr/Kr/W need consider SubKPacks + static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element + static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave + static constexpr index_t Block_Kr = Block_K / Warp_K; // 4 + + static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2 + static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8 + static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4 + + static CK_TILE_DEVICE constexpr auto MakeCBlockDist() + { + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2, 1>, // !! note here is different + sequence<0, 0>>{}; + + using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + return c_block_dstr; + } + + static CK_TILE_DEVICE constexpr auto MakeCBlockTile() + { + using CDataType = float; + constexpr auto c_block_dstr = MakeCBlockDist(); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A() + { + // A async->LDS + // constexpr index_t Block_M = Problem::BlockShape::Block_M0; + // constexpr index_t Block_K = Problem::BlockShape::Block_K0; + // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; + constexpr index_t warpSize = ck_tile::get_warp_size(); + // constexpr index_t NumWarps = Problem::BlockShape::NumWarps; + + constexpr index_t KPack_ = 8; // GetSmemKPack_A(); // LDS + constexpr index_t KVector = 2; // GetAlignment_A(); // async copy 1 dword + constexpr index_t KPad = KPack_; // pad between warps + + static_assert(Block_K % KVector == 0); + constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K + + if constexpr(LanesPerK >= warpSize) + { + // need multiple waves to load K + static_assert(LanesPerK % warpSize == 0); + constexpr index_t wavesPerK = LanesPerK / warpSize; + if constexpr(wavesPerK > NumWarps) + { + // TODO: need multiple issues along K to load all data + } + else + { + constexpr index_t wavesPerM = NumWarps / wavesPerK; + constexpr index_t NumIssues = Block_M / wavesPerM; + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number{}), // k2 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number<1>{}), // k2 + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return lds_block_desc_issues_warps_lanes; + } + } + else + { + // lanes within a wave load different M but same K + static_assert(warpSize % LanesPerK == 0); + constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number<1>{}), // k1 + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return lds_block_desc_issues_warps_lanes; + } + } + + // template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A() + { + // load from LDS to register, every wave has same layout + constexpr index_t KPack_ = 8; // GetSmemKPack_A(); // LDS + constexpr index_t KPad = KPack_; // pad between warps + + constexpr index_t kAMLane = 16; + constexpr index_t kABKLane = 4; + constexpr index_t kABKPerLane = 4; + constexpr index_t kKIter = 2; + static_assert(KPack_ == (kABKPerLane * kKIter)); + + constexpr auto lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, // m0 y + number{}, // m1 p + number{}, // k0 y + number{}, // k1 p + number{}), // k2 y-vector + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number<1>{}), // k2 + number{}, // lds load vector + number<1>{}); + + constexpr auto lds_desc_m_k = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_desc_m_k; + } + + static constexpr auto GetGemm_AWarpEnc() + { + constexpr index_t kAMLane = 16; + constexpr index_t kABKLane = 4; + constexpr index_t kABKPerLane = 4; + constexpr index_t kKIter = 2; + + using enc_ = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>; + return enc_{}; + } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return 32 * (128 + 8) * sizeof(bf16_t); + } +}; + +struct Flatmm_ff_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_ff_32x512x128_1x4x1_16x16x32_Base +{ + using ADataType = bf16_t; + using BDataType = bf16_t; + + // TODO: need paired with tile_window_linear! + // TODO: need call init_raw() before call this function! + template + CK_TILE_DEVICE auto + operator()(const ARes& res_a, + const ACoords& cached_coords_a, + const BRes& res_b, + const BCoords& cached_coords_b, + CK_TILE_LDS_ADDR void* smem, + index_t k, + index_t tile_offset_a, // for each tile, the offset to move for each unroll + index_t tile_offset_b, // for each tile, the offset to move for each unroll + int * dbg_int, + short* dbg_bf16, + float* dbg_fp32) + { +#if 0 + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[UK] Flatmm_ff_32x512x128_1x4x1_16x16x32_BF16 =====\n"); + } + + [[maybe_unused]] uint32_t tidx = threadIdx.x; // 0~255 + [[maybe_unused]] uint32_t tidy = threadIdx.y; // 0~0 + [[maybe_unused]] uint32_t bidx = blockIdx.x; // 0~1 + [[maybe_unused]] uint32_t bidy = blockIdx.y; // 0~51 + [[maybe_unused]] uint32_t bdmx = blockDim.x; // 256 + [[maybe_unused]] uint32_t bdmy = blockDim.y; // 1 + [[maybe_unused]] uint32_t gdmx = gridDim.x; // 2 + [[maybe_unused]] uint32_t gdmy = gridDim.y; // 52 + [[maybe_unused]] uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy + + (bdmx * bdmy) * bidx + + bdmx * tidy + + tidx; +#endif + (void)dbg_int; + (void)dbg_bf16; + (void)dbg_fp32; + + static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8 + static_assert(BCoords::size() == Repeat_N); + + auto a_sst = make_tile_window( + make_tensor_view( + reinterpret_cast(smem), MakeLdsStoreDesc_A()), + MakeLdsStoreDesc_A().get_lengths(), + {0, 0, 0}); + + auto a_sld = [&]() { + constexpr auto a_warp_enc_ = GetGemm_AWarpEnc(); + constexpr auto a_outer_dstr_enc = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = + detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_); + return make_tile_window_linear( + make_tensor_view( + reinterpret_cast(smem), MakeLdsLoadDesc_A()), + MakeLdsLoadDesc_A().get_lengths(), + {0, 0}, + make_static_tile_distribution(a_block_dstr_encode)); + }(); + + const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType); + const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType); + + const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst); + constexpr auto smem_buf_size = + MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType); + static_assert(a_sld.get_num_of_access() == 8); + constexpr auto sld_os = generate_tuple( + [&](auto i_access) { + return number{}; + }, + number{}); + + index_t loop_cnt = k / Block_K; + + // this is the acc thread buffer + fp32x4_t v_acc[16]{.0f}; + // float dbg_dword = 0.0f; + +#pragma region ASM +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Winline-asm" + // clang-format off + asm volatile( +#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16 +#include "uk/flatmm_ff_uk_gfx9_32x512x128_1x1x1_16x16x16.inc" +#undef CK_TILE_FLATMM_UK_MFMA + : [s_loop_cnt]"+s"(loop_cnt), + [v_acc_0] "+v"(v_acc[0]), + [v_acc_1] "+v"(v_acc[1]), + [v_acc_2] "+v"(v_acc[2]), + [v_acc_3] "+v"(v_acc[3]), + [v_acc_4] "+v"(v_acc[4]), + [v_acc_5] "+v"(v_acc[5]), + [v_acc_6] "+v"(v_acc[6]), + [v_acc_7] "+v"(v_acc[7]), + [v_acc_8] "+v"(v_acc[8]), + [v_acc_9] "+v"(v_acc[9]), + [v_acc_10]"+v"(v_acc[10]), + [v_acc_11]"+v"(v_acc[11]), + [v_acc_12]"+v"(v_acc[12]), + [v_acc_13]"+v"(v_acc[13]), + [v_acc_14]"+v"(v_acc[14]), + [v_acc_15]"+v"(v_acc[15]), + //[v_dbg]"+v"(dbg_dword), + [s_mem_]"+r"(smem) + : [s_res_a0]"s"(res_a[0]), + [s_res_a1]"s"(res_a[1]), + [s_res_a2]"s"(res_a[2]), + [s_res_a3]"s"(res_a[3]), + [v_os_a0]"v"(static_cast(cached_coords_a[number<0>{}] * sizeof(ADataType))), + [v_os_a1]"v"(static_cast(cached_coords_a[number<1>{}] * sizeof(ADataType))), + [v_os_a2]"v"(static_cast(cached_coords_a[number<2>{}] * sizeof(ADataType))), + [v_os_a3]"v"(static_cast(cached_coords_a[number<3>{}] * sizeof(ADataType))), + [v_os_a4]"v"(static_cast(cached_coords_a[number<4>{}] * sizeof(ADataType))), + [v_os_a5]"v"(static_cast(cached_coords_a[number<5>{}] * sizeof(ADataType))), + [v_os_a6]"v"(static_cast(cached_coords_a[number<6>{}] * sizeof(ADataType))), + [v_os_a7]"v"(static_cast(cached_coords_a[number<7>{}] * sizeof(ADataType))), + + [s_res_b0]"s"(res_b[0]), + [s_res_b1]"s"(res_b[1]), + [s_res_b2]"s"(res_b[2]), + [s_res_b3]"s"(res_b[3]), + [v_os_b0]"v"(static_cast(cached_coords_b[number<0>{}] * sizeof(BDataType))), + [v_os_b1]"v"(static_cast(cached_coords_b[number<1>{}] * sizeof(BDataType))), + [v_os_b2]"v"(static_cast(cached_coords_b[number<2>{}] * sizeof(BDataType))), + [v_os_b3]"v"(static_cast(cached_coords_b[number<3>{}] * sizeof(BDataType))), + [v_os_b4]"v"(static_cast(cached_coords_b[number<4>{}] * sizeof(BDataType))), + [v_os_b5]"v"(static_cast(cached_coords_b[number<5>{}] * sizeof(BDataType))), + [v_os_b6]"v"(static_cast(cached_coords_b[number<6>{}] * sizeof(BDataType))), + [v_os_b7]"v"(static_cast(cached_coords_b[number<7>{}] * sizeof(BDataType))), + + [v_os_slda]"v"(static_cast(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))), + [s_m0_init]"s"(m0_init_value), + [s_size_per_issue]"s"(size_per_issue), + [smem_sz]"n"(smem_buf_size), //(smem_buf_size), + [sld_os_0]"n"(sld_os[number<0>{}].value), + [sld_os_1]"n"(sld_os[number<1>{}].value), + [sld_os_2]"n"(sld_os[number<2>{}].value), + [sld_os_3]"n"(sld_os[number<3>{}].value), + [sld_os_4]"n"(sld_os[number<4>{}].value), + [sld_os_5]"n"(sld_os[number<5>{}].value), + [sld_os_6]"n"(sld_os[number<6>{}].value), + [sld_os_7]"n"(sld_os[number<7>{}].value), + [s_tile_os_a]"s"(tile_offset_a_bytes), + [s_tile_os_b]"s"(tile_offset_b_bytes) + : "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", + "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", + "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", + "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", + "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", + "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", + "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", + "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", + "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", + "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", + "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", + "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", + "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", + "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", + "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", + "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", + "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", + "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", + "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", + "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", + "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", + "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", + "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", + "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", + "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", + "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", + "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", + "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", + "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", + "a252", "a253", "a254", "a255", + "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", + "s86", // s86 as tmp + "v64", "v65", "v66", "v67", "v68", "v69", + "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", + "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", + "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", + "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", + "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", + "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", + "v124", "v125", "v126", "v127" + ); + // clang-format on +#pragma clang diagnostic pop +#pragma endregion + + // return local scratch + auto c = MakeCBlockTile(); + for(auto i = 0; i < 16; i++) + { + c.get_thread_buffer()[4 * i + 0] = v_acc[i].x; + c.get_thread_buffer()[4 * i + 1] = v_acc[i].y; + c.get_thread_buffer()[4 * i + 2] = v_acc[i].z; + c.get_thread_buffer()[4 * i + 3] = v_acc[i].w; + } + + /*float * vacc0x = reinterpret_cast(v_acc); + short * pdbgf16 = reinterpret_cast(vacc0x); + //short * pdbg_u8 = reinterpret_cast(&dbg_dword); + int dbgCntPerThd = 8; + for(int i = 0; i < dbgCntPerThd; i++) + { + dbg_bf16[gid * dbgCntPerThd + i] = *(pdbgf16 + i); + //dbg_bf16[gid * dbgCntPerThd + i] = *(pdbg_u8 + i); + } + int * pdbg_i32 = reinterpret_cast(&dbg_dword); + dbg_int[gid] = *(pdbg_i32);*/ + return c; + } +}; + +struct Flatmm_ff_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_ff_32x512x128_1x4x1_16x16x32_Base +{ + using ADataType = fp16_t; + using BDataType = fp16_t; + + // TODO: need paired with tile_window_linear! + // TODO: need call init_raw() before call this function! + template + CK_TILE_DEVICE auto + operator()(const ARes& res_a, + const ACoords& cached_coords_a, + const BRes& res_b, + const BCoords& cached_coords_b, + CK_TILE_LDS_ADDR void* smem, + index_t k, + index_t tile_offset_a, // for each tile, the offset to move for each unroll + index_t tile_offset_b, + int * dbg_int, + short* dbg_bf16, + float* dbg_fp32) + { + (void)dbg_int; + (void)dbg_fp32; + (void)dbg_bf16; + static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8 + static_assert(BCoords::size() == Repeat_N); + + auto a_sst = make_tile_window( + make_tensor_view( + reinterpret_cast(smem), MakeLdsStoreDesc_A()), + MakeLdsStoreDesc_A().get_lengths(), + {0, 0, 0}); + + auto a_sld = [&]() { + constexpr auto a_warp_enc_ = GetGemm_AWarpEnc(); + constexpr auto a_outer_dstr_enc = tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + constexpr auto a_block_dstr_encode = + detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_); + return make_tile_window_linear( + make_tensor_view( + reinterpret_cast(smem), MakeLdsLoadDesc_A()), + MakeLdsLoadDesc_A().get_lengths(), + {0, 0}, + make_static_tile_distribution(a_block_dstr_encode)); + }(); + + const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType); + const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType); + + const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst); + constexpr auto smem_buf_size = + MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType); + static_assert(a_sld.get_num_of_access() == 8); + constexpr auto sld_os = generate_tuple( + [&](auto i_access) { + return number{}; + }, + number{}); + + index_t loop_cnt = k / Block_K; + + // this is the acc thread buffer + fp32x4_t v_acc[16]{.0f}; + float dbg_dword = 0.0f; + + // B nr->kr +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Winline-asm" + // clang-format off + asm volatile( +#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16 +#include "uk/flatmm_ff_uk_gfx9_32x512x128_1x1x1_16x16x16.inc" +#undef CK_TILE_FLATMM_UK_MFMA + : [s_loop_cnt]"+s"(loop_cnt), + [v_acc_0]"+v"(v_acc[0]), + [v_acc_1]"+v"(v_acc[1]), + [v_acc_2]"+v"(v_acc[2]), + [v_acc_3]"+v"(v_acc[3]), + [v_acc_4]"+v"(v_acc[4]), + [v_acc_5]"+v"(v_acc[5]), + [v_acc_6]"+v"(v_acc[6]), + [v_acc_7]"+v"(v_acc[7]), + [v_acc_8]"+v"(v_acc[8]), + [v_acc_9]"+v"(v_acc[9]), + [v_acc_10]"+v"(v_acc[10]), + [v_acc_11]"+v"(v_acc[11]), + [v_acc_12]"+v"(v_acc[12]), + [v_acc_13]"+v"(v_acc[13]), + [v_acc_14]"+v"(v_acc[14]), + [v_acc_15]"+v"(v_acc[15]), + [s_mem_]"+r"(smem) + : [s_res_a0]"s"(res_a[0]), + [s_res_a1]"s"(res_a[1]), + [s_res_a2]"s"(res_a[2]), + [s_res_a3]"s"(res_a[3]), + [s_res_b0]"s"(res_b[0]), + [s_res_b1]"s"(res_b[1]), + [s_res_b2]"s"(res_b[2]), + [s_res_b3]"s"(res_b[3]), + [v_os_a0]"v"(static_cast(cached_coords_a[number<0>{}] * sizeof(ADataType))), + [v_os_a1]"v"(static_cast(cached_coords_a[number<1>{}] * sizeof(ADataType))), + [v_os_a2]"v"(static_cast(cached_coords_a[number<2>{}] * sizeof(ADataType))), + [v_os_a3]"v"(static_cast(cached_coords_a[number<3>{}] * sizeof(ADataType))), + [v_os_a4]"v"(static_cast(cached_coords_a[number<4>{}] * sizeof(ADataType))), + [v_os_a5]"v"(static_cast(cached_coords_a[number<5>{}] * sizeof(ADataType))), + [v_os_a6]"v"(static_cast(cached_coords_a[number<6>{}] * sizeof(ADataType))), + [v_os_a7]"v"(static_cast(cached_coords_a[number<7>{}] * sizeof(ADataType))), + + [v_os_b0]"v"(static_cast(cached_coords_b[number<0>{}] * sizeof(BDataType))), + [v_os_b1]"v"(static_cast(cached_coords_b[number<1>{}] * sizeof(BDataType))), + [v_os_b2]"v"(static_cast(cached_coords_b[number<2>{}] * sizeof(BDataType))), + [v_os_b3]"v"(static_cast(cached_coords_b[number<3>{}] * sizeof(BDataType))), + [v_os_b4]"v"(static_cast(cached_coords_b[number<4>{}] * sizeof(BDataType))), + [v_os_b5]"v"(static_cast(cached_coords_b[number<5>{}] * sizeof(BDataType))), + [v_os_b6]"v"(static_cast(cached_coords_b[number<6>{}] * sizeof(BDataType))), + [v_os_b7]"v"(static_cast(cached_coords_b[number<7>{}] * sizeof(BDataType))), + [v_dbg]"v"(dbg_dword), + + [v_os_slda]"v"(static_cast(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))), + [s_m0_init]"s"(m0_init_value), + [s_size_per_issue]"s"(size_per_issue), + [smem_sz]"n"(smem_buf_size), //(smem_buf_size), + [sld_os_0]"n"(sld_os[number<0>{}].value), + [sld_os_1]"n"(sld_os[number<1>{}].value), + [sld_os_2]"n"(sld_os[number<2>{}].value), + [sld_os_3]"n"(sld_os[number<3>{}].value), + [sld_os_4]"n"(sld_os[number<4>{}].value), + [sld_os_5]"n"(sld_os[number<5>{}].value), + [sld_os_6]"n"(sld_os[number<6>{}].value), + [sld_os_7]"n"(sld_os[number<7>{}].value), + [s_tile_os_a]"s"(tile_offset_a_bytes), + [s_tile_os_b]"s"(tile_offset_b_bytes) + : "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", + "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", + "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", + "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", + "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49", + "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", + "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69", + "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79", + "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", + "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99", + "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", + "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115", + "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", + "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131", + "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", + "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147", + "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", + "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163", + "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", + "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179", + "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", + "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195", + "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", + "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211", + "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", + "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227", + "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", + "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243", + "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", + "a252", "a253", "a254", "a255", + "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", + "s86", // s86 as tmp + "v64", "v65", "v66", "v67", "v68", "v69", + "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79", + "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", + "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99", + "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", + "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115", + "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", + "v124", "v125", "v126", "v127" + ); + // clang-format on +#pragma clang diagnostic pop + + // return local scratch + auto c = MakeCBlockTile(); + for(auto i = 0; i < 16; i++) + { + c.get_thread_buffer()[4 * i + 0] = v_acc[i].x; + c.get_thread_buffer()[4 * i + 1] = v_acc[i].y; + c.get_thread_buffer()[4 * i + 2] = v_acc[i].z; + c.get_thread_buffer()[4 * i + 3] = v_acc[i].w; + } + return c; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/block/uk/flatmm_ff_uk_gfx9_32x512x128_1x1x1_16x16x16.inc b/include/ck_tile/ops/flatmm/block/uk/flatmm_ff_uk_gfx9_32x512x128_1x1x1_16x16x16.inc new file mode 100644 index 0000000000..58f735e769 --- /dev/null +++ b/include/ck_tile/ops/flatmm/block/uk/flatmm_ff_uk_gfx9_32x512x128_1x1x1_16x16x16.inc @@ -0,0 +1,574 @@ +#ifndef CK_TILE_FLATMM_UK_MFMA +#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16 +#endif + +#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16 +#define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16" +#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16 +#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16" +#endif + + + + "s_mov_b32 s16, %[s_res_a0] \n" + "s_mov_b32 s17, %[s_res_a1] \n" + "s_mov_b32 s18, %[s_res_a2] \n" + "s_mov_b32 s19, %[s_res_a3] \n" + "s_mov_b32 s20, %[s_res_b0] \n" + "s_mov_b32 s21, %[s_res_b1] \n" + "s_mov_b32 s22, %[s_res_b2] \n" + "s_mov_b32 s23, %[s_res_b3] \n" + // "s_nop 4\n" + "; -- prefetch A0\n" + "s_add_u32 m0, 0, %[s_m0_init] \n" + "buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" // size_per_issue = 1088 + "buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n" + + "s_add_u32 m0, %[smem_sz], %[s_m0_init] \n" // smem_sz = 8688 + "s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond \n" + "s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n" // tile_offset_a_bytes = 256 = 128(bf16)K + "s_add_u32 s16, s86, s16 ; move a with cond \n" + "s_addc_u32 s17, 0, s17 ; move a with cond \n" + + "; -- prefetch A1\n" + "buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n" + "s_add_u32 m0, %[s_size_per_issue], m0 \n" + "buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n" + + "s_add_u32 m0, 0, %[s_m0_init] \n" + "s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n" + "s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n" + "s_add_u32 s16, s86, s16 ; move a with cond \n" + "s_addc_u32 s17, 0, s17 ; move a with cond \n" + + "; -- prefetch B0\n" + "buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n" // bf16 * 8 + "buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n" + "buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n" + "buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n" + "buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n" + "buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n" + "buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n" + "buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n" + "buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n" + "buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n" + "buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n" + "buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n" + "buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n" + "buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n" + "buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n" + "buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n" + "buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n" + "buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n" + "buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n" + "buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n" + "buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n" + "buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n" + "buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n" + "buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n" + "buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n" + "buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n" + "buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n" + "buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n" + "buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n" + "buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n" + "buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n" + "buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n" + +#if 0 + // test a + //" s_waitcnt vmcnt(0) & lgkmcnt(0) \n" + //" s_barrier \n" + //"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 // K stride + //"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n" + //"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n" + //"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n" + //"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n" + //"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n" + //"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n" + //"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n" + //" s_waitcnt vmcnt(0) & lgkmcnt(0) \n" + //" s_barrier \n" + //"v_add_u32 %[v_dbg], %[v_os_slda], %[sld_os_0] \n" + //"v_mov_b32 %[v_dbg], v64 \n" + + // test b + " s_waitcnt vmcnt(0) & lgkmcnt(0) \n" + " s_barrier \n" + //"s_add_u32 s20, s86, s20 ; move b with cond \n" + //"s_addc_u32 s21, 0, s21 ; move b with cond \n" + "buffer_load_dwordx4 %[v_acc_0], %[v_os_b0], s[20:23], 0 offen offset:1024 \n" + " s_waitcnt vmcnt(0) & lgkmcnt(0) \n" + " s_barrier \n" + "v_mov_b32 %[v_dbg], %[v_os_b2] \n" +#else + + "s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n" + "s_cselect_b32 s86, %[s_tile_os_b], 0 ; move b with cond \n" // s_tile_os_b = 4096B = 2048(bf16)K + "s_add_u32 s20, s86, s20 ; move b with cond \n" + "s_addc_u32 s21, 0, s21 ; move b with cond \n" + "s_waitcnt vmcnt(40) \n" + "s_barrier \n" + + "ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 // K stride + "ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n" + "ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n" + "ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n" + "ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n" + "ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n" + "ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n" + "ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n" + + ////////////////////////////////////////////////////////////////////// + "L_start%=: \n" + " s_waitcnt vmcnt(24) & lgkmcnt(0) \n" + " s_barrier \n" + _UK_MFMA_ " %[v_acc_0], acc[0:1], v[64:65], %[v_acc_0] \n" + _UK_MFMA_ " %[v_acc_0], acc[2:3], v[66:67], %[v_acc_0] \n" + " buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_0], acc[4:5], v[68:69], %[v_acc_0] \n" + _UK_MFMA_ " %[v_acc_0], acc[6:7], v[70:71], %[v_acc_0] \n" + " buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + _UK_MFMA_ " %[v_acc_0], acc[8:9], v[72:73], %[v_acc_0] \n" + _UK_MFMA_ " %[v_acc_0], acc[10:11], v[74:75], %[v_acc_0] \n" + " buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_0], acc[12:13], v[76:77], %[v_acc_0] \n" + _UK_MFMA_ " %[v_acc_0], acc[14:15], v[78:79], %[v_acc_0] \n" + " buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + + _UK_MFMA_ " %[v_acc_1], acc[0:1], v[80:81], %[v_acc_1] \n" + _UK_MFMA_ " %[v_acc_1], acc[2:3], v[82:83], %[v_acc_1] \n" + " buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_1], acc[4:5], v[84:85], %[v_acc_1] \n" + _UK_MFMA_ " %[v_acc_1], acc[6:7], v[86:87], %[v_acc_1] \n" + " buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + _UK_MFMA_ " %[v_acc_1], acc[8:9], v[88:89], %[v_acc_1] \n" + _UK_MFMA_ " %[v_acc_1], acc[10:11], v[90:91], %[v_acc_1] \n" + " buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_1], acc[12:13], v[92:93], %[v_acc_1] \n" + _UK_MFMA_ " %[v_acc_1], acc[14:15], v[94:95], %[v_acc_1] \n" + " buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + + _UK_MFMA_ " %[v_acc_2], acc[16:17], v[64:65], %[v_acc_2] \n" + _UK_MFMA_ " %[v_acc_2], acc[18:19], v[66:67], %[v_acc_2] \n" + " buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_2], acc[20:21], v[68:69], %[v_acc_2] \n" + _UK_MFMA_ " %[v_acc_2], acc[22:23], v[70:71], %[v_acc_2] \n" + " buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + _UK_MFMA_ " %[v_acc_2], acc[24:25], v[72:73], %[v_acc_2] \n" + _UK_MFMA_ " %[v_acc_2], acc[26:27], v[74:75], %[v_acc_2] \n" + " buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_2], acc[28:29], v[76:77], %[v_acc_2] \n" + _UK_MFMA_ " %[v_acc_2], acc[30:31], v[78:79], %[v_acc_2] \n" + " buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + + _UK_MFMA_ " %[v_acc_3], acc[16:17], v[80:81], %[v_acc_3] \n" + _UK_MFMA_ " %[v_acc_3], acc[18:19], v[82:83], %[v_acc_3] \n" + " buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_3], acc[20:21], v[84:85], %[v_acc_3] \n" + _UK_MFMA_ " %[v_acc_3], acc[22:23], v[86:87], %[v_acc_3] \n" + " buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + _UK_MFMA_ " %[v_acc_3], acc[24:25], v[88:89], %[v_acc_3] \n" + _UK_MFMA_ " %[v_acc_3], acc[26:27], v[90:91], %[v_acc_3] \n" + " buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_3], acc[28:29], v[92:93], %[v_acc_3] \n" + _UK_MFMA_ " %[v_acc_3], acc[30:31], v[94:95], %[v_acc_3] \n" + " buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[smem_sz], %[s_m0_init] \n" + " s_waitcnt vmcnt(32) \n" + + _UK_MFMA_ " %[v_acc_4], acc[32:33], v[64:65], %[v_acc_4] \n" + _UK_MFMA_ " %[v_acc_4], acc[34:35], v[66:67], %[v_acc_4] \n" + " buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_4], acc[36:37], v[68:69], %[v_acc_4] \n" + _UK_MFMA_ " %[v_acc_4], acc[38:39], v[70:71], %[v_acc_4] \n" + " ds_read_b128 v[96:99], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_0] \n" + _UK_MFMA_ " %[v_acc_4], acc[40:41], v[72:73], %[v_acc_4] \n" + _UK_MFMA_ " %[v_acc_4], acc[42:43], v[74:75], %[v_acc_4] \n" + " buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_4], acc[44:45], v[76:77], %[v_acc_4] \n" + _UK_MFMA_ " %[v_acc_4], acc[46:47], v[78:79], %[v_acc_4] \n" + " ds_read_b128 v[100:103], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_1] \n" + + _UK_MFMA_ " %[v_acc_5], acc[32:33], v[80:81], %[v_acc_5] \n" + _UK_MFMA_ " %[v_acc_5], acc[34:35], v[82:83], %[v_acc_5] \n" + " buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_5], acc[36:37], v[84:85], %[v_acc_5] \n" + _UK_MFMA_ " %[v_acc_5], acc[38:39], v[86:87], %[v_acc_5] \n" + " ds_read_b128 v[104:107], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_2] \n" + _UK_MFMA_ " %[v_acc_5], acc[40:41], v[88:89], %[v_acc_5] \n" + _UK_MFMA_ " %[v_acc_5], acc[42:43], v[90:91], %[v_acc_5] \n" + " buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_5], acc[44:45], v[92:93], %[v_acc_5] \n" + _UK_MFMA_ " %[v_acc_5], acc[46:47], v[94:95], %[v_acc_5] \n" + " ds_read_b128 v[108:111], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_3] \n" + + _UK_MFMA_ " %[v_acc_6], acc[48:49], v[64:65], %[v_acc_6] \n" + _UK_MFMA_ " %[v_acc_6], acc[50:51], v[66:67], %[v_acc_6] \n" + " buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_6], acc[52:53], v[68:69], %[v_acc_6] \n" + _UK_MFMA_ " %[v_acc_6], acc[54:55], v[70:71], %[v_acc_6] \n" + " ds_read_b128 v[112:115], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_4] \n" + _UK_MFMA_ " %[v_acc_6], acc[56:57], v[72:73], %[v_acc_6] \n" + _UK_MFMA_ " %[v_acc_6], acc[58:59], v[74:75], %[v_acc_6] \n" + " buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_6], acc[60:61], v[76:77], %[v_acc_6] \n" + _UK_MFMA_ " %[v_acc_6], acc[62:63], v[78:79], %[v_acc_6] \n" + " ds_read_b128 v[116:119], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_5] \n" + + _UK_MFMA_ " %[v_acc_7], acc[48:49], v[80:81], %[v_acc_7] \n" + _UK_MFMA_ " %[v_acc_7], acc[50:51], v[82:83], %[v_acc_7] \n" + " buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_7], acc[52:53], v[84:85], %[v_acc_7] \n" + _UK_MFMA_ " %[v_acc_7], acc[54:55], v[86:87], %[v_acc_7] \n" + " ds_read_b128 v[120:123], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_6] \n" + _UK_MFMA_ " %[v_acc_7], acc[56:57], v[88:89], %[v_acc_7] \n" + _UK_MFMA_ " %[v_acc_7], acc[58:59], v[90:91], %[v_acc_7] \n" + " buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_7], acc[60:61], v[92:93], %[v_acc_7] \n" + _UK_MFMA_ " %[v_acc_7], acc[62:63], v[94:95], %[v_acc_7] \n" + " ds_read_b128 v[124:127], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_7] \n" + " s_waitcnt vmcnt(32) \n" + + _UK_MFMA_ " %[v_acc_8], acc[64:65], v[64:65], %[v_acc_8] \n" + _UK_MFMA_ " %[v_acc_8], acc[66:67], v[66:67], %[v_acc_8] \n" + " buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_8], acc[68:69], v[68:69], %[v_acc_8] \n" + _UK_MFMA_ " %[v_acc_8], acc[70:71], v[70:71], %[v_acc_8] \n" + _UK_MFMA_ " %[v_acc_8], acc[72:73], v[72:73], %[v_acc_8] \n" + _UK_MFMA_ " %[v_acc_8], acc[74:75], v[74:75], %[v_acc_8] \n" + " buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_8], acc[76:77], v[76:77], %[v_acc_8] \n" + _UK_MFMA_ " %[v_acc_8], acc[78:79], v[78:79], %[v_acc_8] \n" + + _UK_MFMA_ " %[v_acc_9], acc[64:65], v[80:81], %[v_acc_9] \n" + _UK_MFMA_ " %[v_acc_9], acc[66:67], v[82:83], %[v_acc_9] \n" + " buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_9], acc[68:69], v[84:85], %[v_acc_9] \n" + _UK_MFMA_ " %[v_acc_9], acc[70:71], v[86:87], %[v_acc_9] \n" + _UK_MFMA_ " %[v_acc_9], acc[72:73], v[88:89], %[v_acc_9] \n" + _UK_MFMA_ " %[v_acc_9], acc[74:75], v[90:91], %[v_acc_9] \n" + " buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_9], acc[76:77], v[92:93], %[v_acc_9] \n" + _UK_MFMA_ " %[v_acc_9], acc[78:79], v[94:95], %[v_acc_9] \n" + + _UK_MFMA_ " %[v_acc_10], acc[80:81], v[64:65], %[v_acc_10] \n" + _UK_MFMA_ " %[v_acc_10], acc[82:83], v[66:67], %[v_acc_10] \n" + " buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_10], acc[84:85], v[68:69], %[v_acc_10] \n" + _UK_MFMA_ " %[v_acc_10], acc[86:87], v[70:71], %[v_acc_10] \n" + _UK_MFMA_ " %[v_acc_10], acc[88:89], v[72:73], %[v_acc_10] \n" + _UK_MFMA_ " %[v_acc_10], acc[90:91], v[74:75], %[v_acc_10] \n" + " buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_10], acc[92:93], v[76:77], %[v_acc_10] \n" + _UK_MFMA_ " %[v_acc_10], acc[94:95], v[78:79], %[v_acc_10] \n" + + _UK_MFMA_ " %[v_acc_11], acc[80:81], v[80:81], %[v_acc_11] \n" + _UK_MFMA_ " %[v_acc_11], acc[82:83], v[82:83], %[v_acc_11] \n" + " buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_11], acc[84:85], v[84:85], %[v_acc_11] \n" + _UK_MFMA_ " %[v_acc_11], acc[86:87], v[86:87], %[v_acc_11] \n" + _UK_MFMA_ " %[v_acc_11], acc[88:89], v[88:89], %[v_acc_11] \n" + _UK_MFMA_ " %[v_acc_11], acc[90:91], v[90:91], %[v_acc_11] \n" + " buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_11], acc[92:93], v[92:93], %[v_acc_11] \n" + _UK_MFMA_ " %[v_acc_11], acc[94:95], v[94:95], %[v_acc_11] \n" + " s_waitcnt vmcnt(32) \n" + + _UK_MFMA_ " %[v_acc_12], acc[96:97], v[64:65], %[v_acc_12] \n" + _UK_MFMA_ " %[v_acc_12], acc[98:99], v[66:67], %[v_acc_12] \n" + " buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_12], acc[100:101], v[68:69], %[v_acc_12] \n" + _UK_MFMA_ " %[v_acc_12], acc[102:103], v[70:71], %[v_acc_12] \n" + _UK_MFMA_ " %[v_acc_12], acc[104:105], v[72:73], %[v_acc_12] \n" + _UK_MFMA_ " %[v_acc_12], acc[106:107], v[74:75], %[v_acc_12] \n" + " buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_12], acc[108:109], v[76:77], %[v_acc_12] \n" + _UK_MFMA_ " %[v_acc_12], acc[110:111], v[78:79], %[v_acc_12] \n" + + _UK_MFMA_ " %[v_acc_13], acc[96:97], v[80:81], %[v_acc_13] \n" + _UK_MFMA_ " %[v_acc_13], acc[98:99], v[82:83], %[v_acc_13] \n" + " buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_13], acc[100:101], v[84:85], %[v_acc_13] \n" + _UK_MFMA_ " %[v_acc_13], acc[102:103], v[86:87], %[v_acc_13] \n" + _UK_MFMA_ " %[v_acc_13], acc[104:105], v[88:89], %[v_acc_13] \n" + _UK_MFMA_ " %[v_acc_13], acc[106:107], v[90:91], %[v_acc_13] \n" + " buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_13], acc[108:109], v[92:93], %[v_acc_13] \n" + _UK_MFMA_ " %[v_acc_13], acc[110:111], v[94:95], %[v_acc_13] \n" + + _UK_MFMA_ " %[v_acc_14], acc[112:113], v[64:65], %[v_acc_14] \n" + _UK_MFMA_ " %[v_acc_14], acc[114:115], v[66:67], %[v_acc_14] \n" + " buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_14], acc[116:117], v[68:69], %[v_acc_14] \n" + _UK_MFMA_ " %[v_acc_14], acc[118:119], v[70:71], %[v_acc_14] \n" + _UK_MFMA_ " %[v_acc_14], acc[120:121], v[72:73], %[v_acc_14] \n" + _UK_MFMA_ " %[v_acc_14], acc[122:123], v[74:75], %[v_acc_14] \n" + " buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_14], acc[124:125], v[76:77], %[v_acc_14] \n" + _UK_MFMA_ " %[v_acc_14], acc[126:127], v[78:79], %[v_acc_14] \n" + + _UK_MFMA_ " %[v_acc_15], acc[112:113], v[80:81], %[v_acc_15] \n" + _UK_MFMA_ " %[v_acc_15], acc[114:115], v[82:83], %[v_acc_15] \n" + " buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_15], acc[116:117], v[84:85], %[v_acc_15] \n" + _UK_MFMA_ " %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15] \n" + _UK_MFMA_ " %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15] \n" + _UK_MFMA_ " %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15] \n" + " buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072\n" + _UK_MFMA_ " %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15] \n" + _UK_MFMA_ " %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15] \n" + //////////////////////////////////////////////////// + " s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n" + " s_cmp_gt_i32 %[s_loop_cnt] 0 \n" + " s_cbranch_scc0 L_end%= \n" + " s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n" + " s_cselect_b32 s86, %[s_tile_os_a], 0 \n" + " s_add_u32 s16, s86, s16 \n" + " s_addc_u32 s17, 0, s17 \n" + " s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n" + " s_cselect_b32 s86, %[s_tile_os_b], 0 \n" + " s_add_u32 s20, s86, s20 \n" + " s_addc_u32 s21, 0, s21 \n" + " ;------------------------------------------ \n" + " s_waitcnt vmcnt(24) & lgkmcnt(0) \n" + " s_barrier \n" + /////////////////////////////////////////////////////// + _UK_MFMA_ " %[v_acc_0], acc[128:129], v[96:97], %[v_acc_0] \n" + _UK_MFMA_ " %[v_acc_0], acc[130:131], v[98:99], %[v_acc_0] \n" + " buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_0], acc[132:133], v[100:101], %[v_acc_0] \n" + _UK_MFMA_ " %[v_acc_0], acc[134:135], v[102:103], %[v_acc_0] \n" + " buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + _UK_MFMA_ " %[v_acc_0], acc[136:137], v[104:105], %[v_acc_0] \n" + _UK_MFMA_ " %[v_acc_0], acc[138:139], v[106:107], %[v_acc_0] \n" + " buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_0], acc[140:141], v[108:109], %[v_acc_0] \n" + _UK_MFMA_ " %[v_acc_0], acc[142:143], v[110:111], %[v_acc_0] \n" + " buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + _UK_MFMA_ " %[v_acc_1], acc[128:129], v[112:113], %[v_acc_1] \n" + _UK_MFMA_ " %[v_acc_1], acc[130:131], v[114:115], %[v_acc_1] \n" + " buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_1], acc[132:133], v[116:117], %[v_acc_1] \n" + _UK_MFMA_ " %[v_acc_1], acc[134:135], v[118:119], %[v_acc_1] \n" + " buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + _UK_MFMA_ " %[v_acc_1], acc[136:137], v[120:121], %[v_acc_1] \n" + _UK_MFMA_ " %[v_acc_1], acc[138:139], v[122:123], %[v_acc_1] \n" + " buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_1], acc[140:141], v[124:125], %[v_acc_1] \n" + _UK_MFMA_ " %[v_acc_1], acc[142:143], v[126:127], %[v_acc_1] \n" + " buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + _UK_MFMA_ " %[v_acc_2], acc[144:145], v[96:97], %[v_acc_2] \n" + _UK_MFMA_ " %[v_acc_2], acc[146:147], v[98:99], %[v_acc_2] \n" + " buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_2], acc[148:149], v[100:101], %[v_acc_2] \n" + _UK_MFMA_ " %[v_acc_2], acc[150:151], v[102:103], %[v_acc_2] \n" + " buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + _UK_MFMA_ " %[v_acc_2], acc[152:153], v[104:105], %[v_acc_2] \n" + _UK_MFMA_ " %[v_acc_2], acc[154:155], v[106:107], %[v_acc_2] \n" + " buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_2], acc[156:157], v[108:109], %[v_acc_2] \n" + _UK_MFMA_ " %[v_acc_2], acc[158:159], v[110:111], %[v_acc_2] \n" + " buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + _UK_MFMA_ " %[v_acc_3], acc[144:145], v[112:113], %[v_acc_3] \n" + _UK_MFMA_ " %[v_acc_3], acc[146:147], v[114:115], %[v_acc_3] \n" + " buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_3], acc[148:149], v[116:117], %[v_acc_3] \n" + _UK_MFMA_ " %[v_acc_3], acc[150:151], v[118:119], %[v_acc_3] \n" + " buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n" + " s_add_u32 m0, %[s_size_per_issue], m0 \n" + _UK_MFMA_ " %[v_acc_3], acc[152:153], v[120:121], %[v_acc_3] \n" + _UK_MFMA_ " %[v_acc_3], acc[154:155], v[122:123], %[v_acc_3] \n" + " buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_3], acc[156:157], v[124:125], %[v_acc_3] \n" + _UK_MFMA_ " %[v_acc_3], acc[158:159], v[126:127], %[v_acc_3] \n" + " buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n" + " s_add_u32 m0, 0, %[s_m0_init] \n" + " s_waitcnt vmcnt(32) \n" + _UK_MFMA_ " %[v_acc_4], acc[160:161], v[96:97], %[v_acc_4] \n" + _UK_MFMA_ " %[v_acc_4], acc[162:163], v[98:99], %[v_acc_4] \n" + " buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_4], acc[164:165], v[100:101], %[v_acc_4] \n" + _UK_MFMA_ " %[v_acc_4], acc[166:167], v[102:103], %[v_acc_4] \n" + " ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n" + _UK_MFMA_ " %[v_acc_4], acc[168:169], v[104:105], %[v_acc_4] \n" + _UK_MFMA_ " %[v_acc_4], acc[170:171], v[106:107], %[v_acc_4] \n" + " buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_4], acc[172:173], v[108:109], %[v_acc_4] \n" + _UK_MFMA_ " %[v_acc_4], acc[174:175], v[110:111], %[v_acc_4] \n" + " ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n" + _UK_MFMA_ " %[v_acc_5], acc[160:161], v[112:113], %[v_acc_5] \n" + _UK_MFMA_ " %[v_acc_5], acc[162:163], v[114:115], %[v_acc_5] \n" + " buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_5], acc[164:165], v[116:117], %[v_acc_5] \n" + _UK_MFMA_ " %[v_acc_5], acc[166:167], v[118:119], %[v_acc_5] \n" + " ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] " + "\n" _UK_MFMA_ " %[v_acc_5], acc[168:169], v[120:121], %[v_acc_5] \n" + _UK_MFMA_ " %[v_acc_5], acc[170:171], v[122:123], %[v_acc_5] \n" + " buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_5], acc[172:173], v[124:125], %[v_acc_5] \n" + _UK_MFMA_ " %[v_acc_5], acc[174:175], v[126:127], %[v_acc_5] \n" + " ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] " + "\n" _UK_MFMA_ " %[v_acc_6], acc[176:177], v[96:97], %[v_acc_6] \n" + _UK_MFMA_ " %[v_acc_6], acc[178:179], v[98:99], %[v_acc_6] \n" + " buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_6], acc[180:181], v[100:101], %[v_acc_6] \n" + _UK_MFMA_ " %[v_acc_6], acc[182:183], v[102:103], %[v_acc_6] \n" + " ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] " + "\n" _UK_MFMA_ " %[v_acc_6], acc[184:185], v[104:105], %[v_acc_6] \n" + _UK_MFMA_ " %[v_acc_6], acc[186:187], v[106:107], %[v_acc_6] \n" + " buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_6], acc[188:189], v[108:109], %[v_acc_6] \n" + _UK_MFMA_ " %[v_acc_6], acc[190:191], v[110:111], %[v_acc_6] \n" + " ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] " + "\n" _UK_MFMA_ " %[v_acc_7], acc[176:177], v[112:113], %[v_acc_7] \n" + _UK_MFMA_ " %[v_acc_7], acc[178:179], v[114:115], %[v_acc_7] \n" + " buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_7], acc[180:181], v[116:117], %[v_acc_7] \n" + _UK_MFMA_ " %[v_acc_7], acc[182:183], v[118:119], %[v_acc_7] \n" + " ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] " + "\n" _UK_MFMA_ " %[v_acc_7], acc[184:185], v[120:121], %[v_acc_7] \n" + _UK_MFMA_ " %[v_acc_7], acc[186:187], v[122:123], %[v_acc_7] \n" + " buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_7], acc[188:189], v[124:125], %[v_acc_7] \n" + _UK_MFMA_ " %[v_acc_7], acc[190:191], v[126:127], %[v_acc_7] \n" + " ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n" + " s_waitcnt vmcnt(32) \n" + _UK_MFMA_ " %[v_acc_8], acc[192:193], v[96:97], %[v_acc_8] \n" + _UK_MFMA_ " %[v_acc_8], acc[194:195], v[98:99], %[v_acc_8] \n" + " buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_8], acc[196:197], v[100:101], %[v_acc_8] \n" + _UK_MFMA_ " %[v_acc_8], acc[198:199], v[102:103], %[v_acc_8] \n" + _UK_MFMA_ " %[v_acc_8], acc[200:201], v[104:105], %[v_acc_8] \n" + _UK_MFMA_ " %[v_acc_8], acc[202:203], v[106:107], %[v_acc_8] \n" + " buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_8], acc[204:205], v[108:109], %[v_acc_8] \n" + _UK_MFMA_ " %[v_acc_8], acc[206:207], v[110:111], %[v_acc_8] \n" + _UK_MFMA_ " %[v_acc_9], acc[192:193], v[112:113], %[v_acc_9] \n" + _UK_MFMA_ " %[v_acc_9], acc[194:195], v[114:115], %[v_acc_9] \n" + " buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_9], acc[196:197], v[116:117], %[v_acc_9] \n" + _UK_MFMA_ " %[v_acc_9], acc[198:199], v[118:119], %[v_acc_9] \n" + _UK_MFMA_ " %[v_acc_9], acc[200:201], v[120:121], %[v_acc_9] \n" + _UK_MFMA_ " %[v_acc_9], acc[202:203], v[122:123], %[v_acc_9] \n" + " buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_9], acc[204:205], v[124:125], %[v_acc_9] \n" + _UK_MFMA_ " %[v_acc_9], acc[206:207], v[126:127], %[v_acc_9] \n" + _UK_MFMA_ " %[v_acc_10], acc[208:209], v[96:97], %[v_acc_10] \n" + _UK_MFMA_ " %[v_acc_10], acc[210:211], v[98:99], %[v_acc_10] \n" + " buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_10], acc[212:213], v[100:101], %[v_acc_10] \n" + _UK_MFMA_ " %[v_acc_10], acc[214:215], v[102:103], %[v_acc_10] \n" + _UK_MFMA_ " %[v_acc_10], acc[216:217], v[104:105], %[v_acc_10] \n" + _UK_MFMA_ " %[v_acc_10], acc[218:219], v[106:107], %[v_acc_10] \n" + " buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_10], acc[220:221], v[108:109], %[v_acc_10] \n" + _UK_MFMA_ " %[v_acc_10], acc[222:223], v[110:111], %[v_acc_10] \n" + _UK_MFMA_ " %[v_acc_11], acc[208:209], v[112:113], %[v_acc_11] \n" + _UK_MFMA_ " %[v_acc_11], acc[210:211], v[114:115], %[v_acc_11] \n" + " buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_11], acc[212:213], v[116:117], %[v_acc_11] \n" + _UK_MFMA_ " %[v_acc_11], acc[214:215], v[118:119], %[v_acc_11] \n" + _UK_MFMA_ " %[v_acc_11], acc[216:217], v[120:121], %[v_acc_11] \n" + _UK_MFMA_ " %[v_acc_11], acc[218:219], v[122:123], %[v_acc_11] \n" + " buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_11], acc[220:221], v[124:125], %[v_acc_11] \n" + _UK_MFMA_ " %[v_acc_11], acc[222:223], v[126:127], %[v_acc_11] \n" + " s_waitcnt vmcnt(32) \n" + _UK_MFMA_ " %[v_acc_12], acc[224:225], v[96:97], %[v_acc_12] \n" + _UK_MFMA_ " %[v_acc_12], acc[226:227], v[98:99], %[v_acc_12] \n" + " buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_12], acc[228:229], v[100:101], %[v_acc_12] \n" + _UK_MFMA_ " %[v_acc_12], acc[230:231], v[102:103], %[v_acc_12] \n" + _UK_MFMA_ " %[v_acc_12], acc[232:233], v[104:105], %[v_acc_12] \n" + _UK_MFMA_ " %[v_acc_12], acc[234:235], v[106:107], %[v_acc_12] \n" + " buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_12], acc[236:237], v[108:109], %[v_acc_12] \n" + _UK_MFMA_ " %[v_acc_12], acc[238:239], v[110:111], %[v_acc_12] \n" + _UK_MFMA_ " %[v_acc_13], acc[224:225], v[112:113], %[v_acc_13] \n" + _UK_MFMA_ " %[v_acc_13], acc[226:227], v[114:115], %[v_acc_13] \n" + " buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_13], acc[228:229], v[116:117], %[v_acc_13] \n" + _UK_MFMA_ " %[v_acc_13], acc[230:231], v[118:119], %[v_acc_13] \n" + _UK_MFMA_ " %[v_acc_13], acc[232:233], v[120:121], %[v_acc_13] \n" + _UK_MFMA_ " %[v_acc_13], acc[234:235], v[122:123], %[v_acc_13] \n" + " buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_13], acc[236:237], v[124:125], %[v_acc_13] \n" + _UK_MFMA_ " %[v_acc_13], acc[238:239], v[126:127], %[v_acc_13] \n" + _UK_MFMA_ " %[v_acc_14], acc[240:241], v[96:97], %[v_acc_14] \n" + _UK_MFMA_ " %[v_acc_14], acc[242:243], v[98:99], %[v_acc_14] \n" + " buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n" + _UK_MFMA_ " %[v_acc_14], acc[244:245], v[100:101], %[v_acc_14] \n" + _UK_MFMA_ " %[v_acc_14], acc[246:247], v[102:103], %[v_acc_14] \n" + _UK_MFMA_ " %[v_acc_14], acc[248:249], v[104:105], %[v_acc_14] \n" + _UK_MFMA_ " %[v_acc_14], acc[250:251], v[106:107], %[v_acc_14] \n" + " buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n" + _UK_MFMA_ " %[v_acc_14], acc[252:253], v[108:109], %[v_acc_14] \n" + _UK_MFMA_ " %[v_acc_14], acc[254:255], v[110:111], %[v_acc_14] \n" + _UK_MFMA_ " %[v_acc_15], acc[240:241], v[112:113], %[v_acc_15] \n" + _UK_MFMA_ " %[v_acc_15], acc[242:243], v[114:115], %[v_acc_15] \n" + " buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n" + _UK_MFMA_ " %[v_acc_15], acc[244:245], v[116:117], %[v_acc_15] \n" + _UK_MFMA_ " %[v_acc_15], acc[246:247], v[118:119], %[v_acc_15] \n" + _UK_MFMA_ " %[v_acc_15], acc[248:249], v[120:121], %[v_acc_15] \n" + _UK_MFMA_ " %[v_acc_15], acc[250:251], v[122:123], %[v_acc_15] \n" + " buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n" + _UK_MFMA_ " %[v_acc_15], acc[252:253], v[124:125], %[v_acc_15] \n" + _UK_MFMA_ " %[v_acc_15], acc[254:255], v[126:127], %[v_acc_15] \n" + + " s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n" + " s_cmp_gt_i32 %[s_loop_cnt] 0 \n" + " s_cbranch_scc0 L_end%= \n" + " s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n" + " s_cselect_b32 s86, %[s_tile_os_a], 0 \n" + " s_add_u32 s16, s86, s16 \n" + " s_addc_u32 s17, 0, s17 \n" + " s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n" + " s_cselect_b32 s86, %[s_tile_os_b], 0 \n" + " s_add_u32 s20, s86, s20 \n" + " s_addc_u32 s21, 0, s21 \n" + " s_branch L_start%= \n" + "L_end%=: \n" + " s_nop 2 \n" +#endif + +#undef _UK_MFMA_ + diff --git a/include/ck_tile/ops/flatmm_uk.hpp b/include/ck_tile/ops/flatmm_uk.hpp new file mode 100644 index 0000000000..8759f90066 --- /dev/null +++ b/include/ck_tile/ops/flatmm_uk.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" +#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" +#include "ck_tile/ops/flatmm/block/flatmm_ff_32x512x128_1x4x1_16x16x32.hpp" +#include "ck_tile/ops/fused_moe/kernel/flatmm_uk_kernel.hpp" +#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp" +#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp" +#include "ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp" +#include "ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/fused_moe/kernel/flatmm_uk_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/flatmm_uk_kernel.hpp new file mode 100644 index 0000000000..a7bade2747 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/kernel/flatmm_uk_kernel.hpp @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/elementwise.hpp" +#include +#include + +// clang-format off +// [indexing implementation-1] +// using M_a as constexpr block_size to partition all tokens into different slices +// each slice map to one expert, and one expert can have multiple slices +// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5 +// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]] +// tok-0 tok-1 tok-2 tok-3 tok-4 +// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number) +// +// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]] +// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 +// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] +// +// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) +// * this could be larger than actual, since actual tokens are on GPU +// +// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5] +// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -| +// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o] +// +// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr +// +// * Note on token_id_per_expert/sorted_token_ids_ptr data: +// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr. +// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from +// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr +// +// 32bit 0........23 24.....31 bit +// (data) -> (token_id | topk_id) +// low 24 bit is for token id, top 8 bit is for topk id +// +// the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim] +// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim] +// +// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5] +// * length is (max_num_tokens_padded + block_size - 1) / block_size +// +// num_tokens_post_padded_ptr : [28] +// num_sorted_tiles_ptr : [7] +// +// * different from vLLM +// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id +// 2)need sorted_weight_ptr +// 3) use num_sorted_tiles_ptr, already divided by M_a +// +// * below used for indexing +// 1) sorted_token_ids_ptr [max_num_tokens_padded] +// 2) sorted_weight_ptr +// 3) sorted_expert_ids_ptr +// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one) +// +// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1) +// +// [indexing implementation-2] +// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]] +// tok-0 tok-1 tok-2 tok-3 tok-4 +// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number) +// +// we generate original rol/col id as +// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]] +// let x be one element of above, we can get: +// tpok_row_id(token_id) = x % num_tokens(5) +// tpok_col_id(expert_Id) = x / num_tokens +// topk_row_id/col_id can be used to access original topk_ids/topk_weight +// +// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 5, 5]] +// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5 +// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]] +// +// we can get permuted_rc_ids: +// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]] +// +// +// clang-format on +// +namespace ck_tile { + +// m: num_tokens (or token*input-batch) +// k: intermediate_size +// n: intermediate_size used between 2 FC (TP slice this) +// e: num expert +// if doing pre-shuffle +// nr : n / Block_Nr +// kr : k / Block_Kr +// w : fattened 1d wave buffer +struct FlatmmUkHostArgs +{ + const void* a_ptr; // [m, k], input token + const void* b_ptr; // [m, k], input token + const void* c_ptr; // [m, k], output token + void* d_ptr; // [m, k], output token + void* dbg_int_ptr; // [m, k], output token + void* dbg_bf16_ptr; // [m, k], output token + void* dbg_fp32_ptr; // [m, k], output token + + index_t hidden_size; // K + index_t intermediate_size; // N + index_t num_tokens; // M + + index_t num_experts; // number of groups + index_t topk; // need this? + index_t stride_token; // for input/output, stride for each row, should >= hidden_size +}; + +// This is scatter/gather b2b group-gemm +template +struct FlatmmUkKernel +{ + using Pipeline = remove_cvref_t; + using Epilogue = remove_cvref_t; // TODO: not used + // static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu; + // static_assert(kBlockPerCu > 0); + + using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape + static constexpr index_t BlockSize_ = BlockShape::BlockSize; + + using ADataType = typename Pipeline::Problem::ADataType; + using GDataType = typename Pipeline::Problem::GDataType; + using DDataType = typename Pipeline::Problem::AccDataType; + using AccDataType = typename Pipeline::Problem::AccDataType; + using ODataType = typename Pipeline::Problem::ODataType; + using AScaleDataType = typename Pipeline::Problem::AScaleDataType; + using GScaleDataType = typename Pipeline::Problem::GScaleDataType; + using DScaleDataType = typename Pipeline::Problem::DScaleDataType; + using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType; + using TopkWeightDataType = typename Pipeline::Problem::TopkWeightDataType; + using IndexDataType = typename Pipeline::Problem::IndexDataType; + using YDataType = typename Pipeline::Problem::YDataType; + + using Traits = typename Pipeline::Problem::Traits; + static constexpr bool UseUK = true; + + static constexpr bool IsGateOnly = Traits::IsGateOnly; + static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; + static constexpr bool PadHiddenSize = Traits::PadHiddenSize; + static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + template <> struct t2s { static constexpr const char * name = "int8"; }; + // clang-format on + + CK_TILE_HOST static std::string GetName() + { +#define _SS_ std::string +#define _TS_ std::to_string + // clang-format off + using S_ = BlockShape; + + auto prec_str = [&] () { + std::string base_str = _SS_(t2s::name); + if (!std::is_same_v) { + base_str += _SS_("_") + _SS_(t2s::name); + } + return base_str; + }(); + + return _SS_("fused_moe_") + _SS_(prec_str) + "_" + + _TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" + + _TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" + + _TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name); +#undef _SS_ +#undef _TS_ + // clang-format on + } + + struct FusedMoeGemmKargs + { + const void* a_ptr; // [m, k], input token + const void* b_ptr; // [m, k], input token + const void* c_ptr; // [m, k], output token + void* d_ptr; // [m, k], output token + void* dbg_int_ptr; // [m, k], output token + void* dbg_bf16_ptr; // [m, k], output token + void* dbg_fp32_ptr; // [m, k], output token + + index_t hidden_size; // K + index_t intermediate_size; // N + index_t num_tokens; // M + + index_t num_experts; // number of groups + index_t topk; // need this? + index_t stride_token; // for input/output, stride for each row, should >= hidden_size + }; + + // TODO: switch karg based on + using Kargs = FusedMoeGemmKargs; + using Hargs = FlatmmUkHostArgs; + + CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) + { + // TODO: hargs/kargs not guranteed to be the same + return bit_cast(hargs); + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) + { + index_t ms = ck_tile::integer_divide_ceil(hargs.num_tokens, BlockShape::Block_M0); + index_t ns = ck_tile::integer_divide_ceil(hargs.intermediate_size, BlockShape::Block_N0); + return dim3(ns, ms, 1); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { +#if 0 + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[KERNEL] FlatmmUkKernel =====\n"); + printf("[KERNEL] blockDim: [%d, %d], gridDim: [%d, %d]\n", + static_cast(blockDim.x), + static_cast(blockDim.y), + static_cast(gridDim.x), + static_cast(gridDim.y)); + printf("[KERNEL] lds = %.3f (KB)\n", GetSmemSize() / 1024.0f); + } + + [[maybe_unused]] uint32_t tidx = threadIdx.x; // 0~255 + [[maybe_unused]] uint32_t tidy = threadIdx.y; // 0~0 + [[maybe_unused]] uint32_t bidx = blockIdx.x; // 0~1 + [[maybe_unused]] uint32_t bidy = blockIdx.y; // 0~51 + [[maybe_unused]] uint32_t bdmx = blockDim.x; // 256 + [[maybe_unused]] uint32_t bdmy = blockDim.y; // 1 + [[maybe_unused]] uint32_t gdmx = gridDim.x; // 2 + [[maybe_unused]] uint32_t gdmy = gridDim.y; // 52 + [[maybe_unused]] uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy + + (bdmx * bdmy) * bidx + + bdmx * tidy + + tidx; + + [[maybe_unused]]int * dbg_int = static_cast(kargs.dbg_int_ptr); + [[maybe_unused]]short * dbg_bf16 = static_cast(kargs.dbg_bf16_ptr); + [[maybe_unused]]float * dbg_fp32 = static_cast(kargs.dbg_fp32_ptr); + + dbg_int[gid] = -1; + dbg_fp32[gid] = -1.0f; +#endif + + __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()]; + + Pipeline{}(kargs, smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline.hpp b/include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline.hpp new file mode 100644 index 0000000000..336f4f3e94 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline.hpp @@ -0,0 +1,337 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp" + +namespace ck_tile { + +/* +This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight) +we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave) + + <----- gemm-N ------> + +----+----+----+----+ + | w0 | w1 | w2 | w3 | gemm-m + +----+----+----+----+ +*/ +template +struct GemmPipeline_FlatmmUk +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape + + using ADataType = typename Problem::ADataType; + using GDataType = typename Problem::GDataType; + using DDataType = typename Problem::AccDataType; + using AccDataType = typename Problem::AccDataType; + using ODataType = typename Problem::ODataType; + using AScaleDataType = typename Problem::AScaleDataType; + using GScaleDataType = typename Problem::GScaleDataType; + using DScaleDataType = typename Problem::DScaleDataType; + using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType; + using TopkWeightDataType = typename Problem::TopkWeightDataType; + using IndexDataType = typename Problem::IndexDataType; + using YDataType = typename Problem::YDataType; + + using Traits = typename Problem::Traits; + + static constexpr bool IsGateOnly = Traits::IsGateOnly; + static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant; + static constexpr bool PadHiddenSize = Traits::PadHiddenSize; + static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize; + + static constexpr index_t kAlignmentA = Policy::template GetAlignment_A(); + static constexpr index_t kAlignmentG = Policy::template GetAlignment_G(); + static constexpr index_t kAlignmentD = Policy::template GetAlignment_D(); + static constexpr index_t kAlignmentO = Policy::template GetAlignment_O(); + + static constexpr index_t SLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::SLD_A); + static constexpr index_t GLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_A); + static constexpr index_t GLD_B = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_B); + static constexpr index_t GST_O = static_cast(FusedMoeGemmPipelineSequencerEnum::GST_O); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::kBlockPerCu != -1) + return Problem::kBlockPerCu; + else + { + // minimize occupancy + return 2; + } + }(); + + static constexpr const char* name = "flatmm_uk"; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + constexpr index_t smem_0 = Policy::template GetUK_0().GetSmemSize(); + constexpr index_t smem_bridge = + BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType); + return max(smem_0, smem_bridge); + } + + // this is the thread-offset along row/col + CK_TILE_HOST_DEVICE static auto GetACoord() + { + constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A(); + const auto a_coord = a_dist.calculate_index(); + return a_coord; + } + + // this is the thread-offset along row/col + CK_TILE_HOST_DEVICE static auto GetOCoord() + { + constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution(); + const auto o_coord = o_dist.calculate_index(); + return o_coord; + } + + CK_TILE_DEVICE constexpr auto GetNumRowCoords_A() + { + constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA; + constexpr index_t MLans = BlockShape::BlockSize / KLans; + constexpr index_t MRepeat = BlockShape::Block_M0 / MLans; + + return MRepeat; + } + + // TODO: properlly support scatter/gather + CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset) + { + constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA; + constexpr index_t MLans = BlockShape::BlockSize / KLans; + constexpr index_t MRepeat = BlockShape::Block_M0 / MLans; + + auto base_coord = threadIdx.x / KLans + base_offset; + + array coords; + static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; }); + + return coords; + } + CK_TILE_DEVICE auto GetRowCoords_O2(index_t base_offset) + { + constexpr index_t NLans = BlockShape::Block_N0 / kAlignmentO; + constexpr index_t MLans = BlockShape::BlockSize / NLans; + constexpr index_t MRepeat = BlockShape::Block_M0 / MLans; + + auto base_coord = threadIdx.x / NLans + base_offset; + + array coords; + static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; }); + + return coords; + } + + template + CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr) + { + constexpr index_t n_size = coords.size(); + + array row_ids; + static_for<0, n_size, 1>{}([&](auto i) { + row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans; + }); + + return row_ids; + } + + template + CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords, + const TopkWeightDataType* sorted_weight_ptr) + { + constexpr index_t n_size = coords.size(); + + array w; + static_for<0, n_size, 1>{}([&](auto i) { + w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans; + }); + + return w; + } + + // TODO: this row id is before shuffle atomic, need use acc distribution + CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset) + { + constexpr index_t MLanes = BlockShape::Warp_M1; + constexpr index_t Repeat_M = BlockShape::Repeat_M1; + + auto base_coord = threadIdx.x % MLanes + base_offset; + + array coords; + static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; }); + + return coords; + } + + template + CK_TILE_DEVICE auto operator()(const Karg& kargs, CK_TILE_LDS_ADDR void* smem) + { +#if 0 + if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0) + { + printf("[PIPE] GemmPipeline_FlatmmUk =====\n"); + } + + [[maybe_unused]] uint32_t tidx = threadIdx.x; // 0~255 + [[maybe_unused]] uint32_t tidy = threadIdx.y; // 0~0 + [[maybe_unused]] uint32_t bidx = blockIdx.x; // 0~1 + [[maybe_unused]] uint32_t bidy = blockIdx.y; // 0~51 + [[maybe_unused]] uint32_t bdmx = blockDim.x; // 256 + [[maybe_unused]] uint32_t bdmy = blockDim.y; // 1 + [[maybe_unused]] uint32_t gdmx = gridDim.x; // 2 + [[maybe_unused]] uint32_t gdmy = gridDim.y; // 52 + [[maybe_unused]] uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy + + (bdmx * bdmy) * bidx + + bdmx * tidy + + tidx; +#endif + [[maybe_unused]] int* dbg_int = static_cast(kargs.dbg_int_ptr); + [[maybe_unused]] short* dbg_bf16 = static_cast(kargs.dbg_bf16_ptr); + [[maybe_unused]] float* dbg_fp32 = static_cast(kargs.dbg_fp32_ptr); + + ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size; // N + index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W + index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W + index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane( + blockIdx.x * BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W) + + // ---------------------------------------------------------------------------- + // a + auto a_res = + make_wave_buffer_resource(reinterpret_cast(kargs.a_ptr), + kargs.num_tokens * kargs.hidden_size * sizeof(ADataType)); + auto row_ids_a = GetRowCoords_A(blockIdx.y * BlockShape::Block_M0); + auto a_coords = generate_tuple( + [&](auto i) { + return row_ids_a[i] * kargs.hidden_size + + threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA; + }, + number{}); + + // ---------------------------------------------------------------------------- + // b + auto b_win = [&]() { + const GDataType* b_ptr = reinterpret_cast(kargs.b_ptr) + + interm_idx_nr0 * kr_0 * BlockShape::Block_W0; + auto b_view_ = make_naive_tensor_view( + b_ptr, + make_tuple(nr_0, kr_0, number{}), + make_tuple(kr_0 * BlockShape::Block_W0, number{}, 1), + number{}, + number<1>{}); + + auto b_window_ = make_tile_window_linear_raw( + b_view_, + make_tuple(number{}, + number{}, + number{}), + {0, 0, 0}, + Policy::template MakeGlobalTileDistribution_G(), + sequence<0, 1, 1>{}); + return b_window_; + }(); + auto b_res = b_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; + auto b_coords = generate_tuple([&](auto i) { return b_win.cached_coords_[i].get_offset(); }, + number{}); + + // ---------------------------------------------------------------------------- + // core + auto uk_0 = Policy::template GetUK_0(); + auto acc_0 = uk_0(a_res, + a_coords, + b_res, + b_coords, + smem, + kargs.hidden_size, + BlockShape::Block_K0, // tile offset for B matrix each unroll + BlockShape::Block_Kr0 * + BlockShape::Block_W0, // tile offset for B matrix each unroll + dbg_int, + dbg_bf16, + dbg_fp32); + + // ---------------------------------------------------------------------------- + { + int tid = threadIdx.x; + float srdfp32 = 0.f; + float* smemfp32 = static_cast(smem); + + // ---------------------------------------------------------------------------- + // store to lds + for(uint32_t accIdx = 0; accIdx < 16; accIdx++) + { + float* accSmem = smemfp32 + 4 * blockDim.x * accIdx; + for(int xyzw = 0; xyzw < 4; xyzw++) + { + accSmem[tid * 4 + xyzw] = acc_0.get_thread_buffer()[accIdx * 4 + xyzw]; + } + } + block_sync_lds(); + + // ---------------------------------------------------------------------------- + // read from lds + int sldIdx = 0; + // int MLn = 15; + // int Nln = tid / MLn; + int tidInWave = tid % 64; + int waveId = tid / 64; + // sldIdx = (tid64 % 16 * 16 + tid64 / 16) % 64 + // + tid / 64; + sldIdx = (tidInWave % 16 * 16 + tidInWave / 16) + waveId * 4; + + const int accNLane = 16; + const int NLaneCnt = BlockShape::Block_N0 / 4; // xyzw 512 / 4 = 128 + const int accBlkSize = blockDim.x; + + int accInnerId = tid % accNLane; // 0~15 + int accNIdx = tid / NLaneCnt; // 0~127 = 0; 128~255 = 1 + int acc01BlkIdx = tid % NLaneCnt / 16; // 0 ~ 7 + int accBlkIdx = acc01BlkIdx * 2; // 0, 2, 4, ..., 14 + int acc4Id = accBlkIdx * accBlkSize // + + accNIdx * accBlkSize + accInnerId * 16; + sldIdx = acc4Id; + + float* d_buf = static_cast(kargs.d_ptr); + int c_blk_offset = blockIdx.y * BlockShape::Block_M0 * kargs.intermediate_size / 4 + + blockIdx.x * BlockShape::Block_N0 / 4; + + for(uint32_t accIdx = 0; accIdx < 16; accIdx++) + { + for(int xyzw = 0; xyzw < 4; xyzw++) + { + srdfp32 = smemfp32[accIdx * (1 * 4) + sldIdx * 4 + xyzw]; + acc_0.get_thread_buffer()[accIdx * 4 + xyzw] = srdfp32; + } + + // ---------------------------------------------------------------------------- + // store to vmem + int c_m_idx_offset = (accIdx + accNIdx * 16) * kargs.intermediate_size / 4; + int c_idx_offset = c_blk_offset + c_m_idx_offset + (tid % NLaneCnt); + + for(int xyzw = 0; xyzw < 4; xyzw++) + { + srdfp32 = acc_0.get_thread_buffer()[accIdx * 4 + xyzw]; + d_buf[c_idx_offset * 4 + xyzw] = srdfp32; + } + } + } + +#if 0 + // ---------------------------------------------------------------------------- + // debug + for(uint32_t dbgi = 0; dbgi < 64; dbgi++) + { + dbg_fp32[gid * 64 + dbgi] = acc_0.get_thread_buffer()[dbgi]; + } +#endif + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp new file mode 100644 index 0000000000..ece045b1e5 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp @@ -0,0 +1,809 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" +#include "ck_tile/ops/flatmm_uk.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile { + +struct GemmPipelineFlatmmPolicy +{ + CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords() + { + // TODO: always 1 dword + return 1; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A() + { + // using async + constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords(); + constexpr index_t data_bytes = sizeof(typename Problem::ADataType); + static_assert(copy_bytes % data_bytes == 0); + return copy_bytes / data_bytes; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G() + { + constexpr index_t copy_bytes = [&]() { return 16; }(); + constexpr index_t data_bytes = sizeof(typename Problem::GDataType); + static_assert(copy_bytes % data_bytes == 0); + return copy_bytes / data_bytes; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D() + { + constexpr index_t copy_bytes = [&]() { return 16; }(); + constexpr index_t data_bytes = sizeof(typename Problem::DDataType); + static_assert(copy_bytes % data_bytes == 0); + return copy_bytes / data_bytes; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O() + { + if constexpr(Problem::Traits::OAtomic == 1) + { + // pack fp16/bf16 atomic + static_assert(sizeof(typename Problem::ODataType) == 2); + return 2; + } + else if constexpr(Problem::Traits::OAtomic == 2) + { + // fp32 atomic + return 1; + } + else + { + return 16 / sizeof(typename Problem::ODataType); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack() + { + // TODO: this is for 3d layout + return 16 / sizeof(remove_cvref_t); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A() + { + return GetSmemKPack(); + } + + // used for bridge LDS shuffle + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y() + { + // TODO: this should match mfma layout + return 16 / sizeof(typename Problem::YDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A() + { + constexpr auto a_sld_desc = MakeLdsLoadDesc_A(); + constexpr auto a_sst_desc = MakeLdsStoreDesc_A(); + static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size()); + return a_sld_desc.get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge() + { + constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc(); + constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc(); + static_assert(bridge_sld_desc.get_element_space_size() == + bridge_sst_desc.get_element_space_size()); + return bridge_sld_desc.get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + constexpr index_t a_lds = GetSmemSize_A(); + constexpr index_t bridge_lds = GetSmemSize_Bridge(); + return max(a_lds, bridge_lds); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK() + { + constexpr index_t K_vec = Alignment; + constexpr index_t K_rem = KPerBlock / K_vec; + + if constexpr(get_warp_size() < K_rem) + { + static_assert(K_rem % get_warp_size() == 0); + constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k + constexpr index_t K_wav = K_rem / get_warp_size(); + static_assert(K_wav <= NumWarps, "not not support thread has repeat along K yet"); + constexpr index_t M_wav = NumWarps / K_wav; + static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / M_wav; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else + { + constexpr index_t K_lan = K_rem; + constexpr index_t M_lan = get_warp_size() / K_lan; + constexpr index_t M_wav = NumWarps; + static_assert(MPerBlock % (M_lan * M_wav) == 0, + "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / (M_lan * M_wav); + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + // optimized version for async, not same as simple MXK dist(pay attention!!) + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async() + { + constexpr index_t K_vec = Alignment; + constexpr index_t K_rem = KPerBlock / K_vec; + + if constexpr(get_warp_size() <= K_rem) + { + static_assert(K_rem % get_warp_size() == 0); + constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k + constexpr index_t K_wav = K_rem / get_warp_size(); + static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet"); + constexpr index_t M_wav = NumWarps / K_wav; + static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / M_wav; + // NOTE: no swap, but hard to avoid LDS bank conflict + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2>>, + tuple, sequence<1>>, + sequence<1, 2>, + sequence<0, 2>>{}); + } + else + { + constexpr index_t K_lan = K_rem; + constexpr index_t M_lan = get_warp_size() / K_lan; + constexpr index_t M_wav = NumWarps; + static_assert(MPerBlock % (M_lan * M_wav) == 0, + "this tile size is too small please check"); + constexpr index_t M_rep = MPerBlock / (M_lan * M_wav); + // NOTE: swapped for LDS load bank conflict free + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + // Note M_wave(num waves) is the fastest dim, different from sipmle 2d + // distribution + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_Nr_Kr_W() + { + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence, + sequence>, + tuple, sequence<3>>, + tuple, sequence<0>>, + sequence<1, 2, 3>, + sequence<0, 0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A() + { + constexpr index_t Block_M_ = Problem::BlockShape::Block_M0; + constexpr index_t Block_K_ = Problem::BlockShape::Block_K0; + constexpr index_t NumWarps_ = Problem::BlockShape::NumWarps; + constexpr index_t Alignment_ = GetAlignment_A(); + return MakeGlobalTileDistribution_SimpleMxK_Async(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G() + { + constexpr auto PermuteEnum = Problem::Traits::PermuteEnum; + // constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2; + using S_ = typename Problem::BlockShape; + if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) + { + // number{}.rrr(); + // number{}.eee(); + return MakeGlobalTileDistribution_Nr_Kr_W()>(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D() + { + constexpr auto PermuteEnum = Problem::Traits::PermuteEnum; + using S_ = typename Problem::BlockShape; + if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten) + { + return MakeGlobalTileDistribution_Nr_Kr_W()>(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + // using CDataType = typename WarpGemm::CDataType; + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + return c_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A() + { + // A async->LDS + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_K = Problem::BlockShape::Block_K0; + // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; + constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t NumWarps = Problem::BlockShape::NumWarps; + + constexpr index_t KPack = GetSmemKPack_A(); // LDS + constexpr index_t KVector = GetAlignment_A(); // async copy 1 dword + constexpr index_t KPad = KPack; // pad between warps + + static_assert(Block_K % KVector == 0); + constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K + if constexpr(LanesPerK >= warpSize) + { + // need multiple waves to load K + static_assert(LanesPerK % warpSize == 0); + constexpr index_t wavesPerK = LanesPerK / warpSize; + if constexpr(wavesPerK > NumWarps) + { + // TODO: need multiple issues along K to load all data + } + else + { + constexpr index_t wavesPerM = NumWarps / wavesPerK; + constexpr index_t NumIssues = Block_M / wavesPerM; + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number{}), // k2 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number<1>{}), // k2 + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple( + make_pass_through_transform(number{}), + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return lds_block_desc_issues_warps_lanes; + } + } + else + { + // lanes within a wave load different M but same K + static_assert(warpSize % LanesPerK == 0); + constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number<1>{}), // k1 + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + return lds_block_desc_issues_warps_lanes; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A() + { + // A async->LDS + // Note that, this descriptor is only to construct the layout inside LDS + // in real Gemm pipeline, ds_read may not follow this pattern + // (may follow that in tile_distribution) + // below code is almost the same as SmemStore dist, with difference: + // 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc + // 2). return discriptor is in NxK 2d layout + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_K = Problem::BlockShape::Block_K0; + // constexpr index_t BlockSize = Problem::BlockShape::BlockSize; + constexpr index_t warpSize = ck_tile::get_warp_size(); + constexpr index_t NumWarps = Problem::BlockShape::NumWarps; + + constexpr index_t KPack = GetSmemKPack_A(); // LDS + constexpr index_t KVector = GetAlignment_A(); // async copy 1 dword + constexpr index_t KPad = KPack; // pad between warps + + static_assert(Block_K % KVector == 0); + constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K + if constexpr(LanesPerK >= warpSize) + { + // need multiple waves to load K + static_assert(LanesPerK % warpSize == 0); + constexpr index_t wavesPerK = LanesPerK / warpSize; + if constexpr(wavesPerK >= NumWarps) + { + // TODO: need multiple issues along K to load all data + } + else + { + constexpr index_t wavesPerM = NumWarps / wavesPerK; + constexpr index_t NumIssues = Block_M / wavesPerM; + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number{}), // k2 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // k0 + number{}, // k1 + number<1>{}), // k2 + number{}, // lds load vector + number<1>{}); + + constexpr auto lds_desc_m_k = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_desc_m_k; + } + } + else + { + // lanes within a wave load different M but same K + static_assert(warpSize % LanesPerK == 0); + constexpr index_t LaneGroups = warpSize / LanesPerK; // along m + constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps); + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number{}), // k1 + make_tuple(number{}, // m0 + number{}, // m1 + number{}, // m2 + number{}, // k0 + number<1>{}), // k1 + number{}, // lds load vector + number<1>{}); + + constexpr auto lds_desc_m_k = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_desc_m_k; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc() + { + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_N = Problem::BlockShape::Block_N0; + + constexpr index_t KVector = GetSmemKPack_Y(); // async copy 1 dword + constexpr index_t KPad = 0; // pad between warps + + constexpr auto desc = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc() + { + constexpr index_t Block_M = Problem::BlockShape::Block_M0; + constexpr index_t Block_N = Problem::BlockShape::Block_N0; + + constexpr index_t KVector = GetSmemKPack_Y(); // async copy 1 dword + constexpr index_t KPad = 0; // KVector; // pad between warps + + constexpr auto desc = + make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreForUKDesc() + { + constexpr index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0; + constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N0; + constexpr index_t Repeat_M = Problem::BlockShape::Repeat_M0; + + constexpr index_t kAMLane = 16; + constexpr index_t kABKLane = 4; + constexpr index_t kABKPerLane = 4; + + constexpr index_t KPack = kABKPerLane; + + constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, // m + number{}, // n + number{}, // n + number{}, // n + number{}, // m + number{}), // n + make_tuple(number{}, // m + number{}, // n + number{}, // n + number{}, // n + number{}, // m + number<1>{}), // n + number{}, // lds store vector(actually no explicit store) + number<1>{}); + + constexpr auto desc = transform_tensor_descriptor( + lds_block_desc_0, + make_tuple(make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, + number{}, + number{}, + number{}))), + make_tuple(sequence<0, 4>{}, sequence<1, 2, 3, 5>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0() + { + using S_ = typename Problem::BlockShape; + // A is vgpr, B is agpr. But since we transposed, so also need swap this + // TODO: this is ugly + constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv; + // TODO: ugly + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16) + { + return WarpGemmImpl, + 2>>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32) + { + return WarpGemmImpl, + 2>>{}; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0() + { + // this function return seq<...> used to identify gld/sld/valu... inside mfma sequence + // the purpose is to hide thoes instructions under mfma + // every value inside seq<...> is a mask, indicating a specific operation + using S_ = typename Problem::BlockShape; + constexpr index_t SLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::SLD_A); + constexpr index_t GLD_A = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_A); + constexpr index_t GLD_B = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_B); + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 7 + return seq_all; + // clang-format on + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 32 instructions, 16 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 3 + return seq_all; + // clang-format on + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1() + { + // this function return seq<...> used to identify gld/sld/valu... inside mfma sequence + // the purpose is to hide thoes instructions under mfma + // every value inside seq<...> is a mask, indicating a specific operation + using S_ = typename Problem::BlockShape; + constexpr index_t GLD_B = static_cast(FusedMoeGemmPipelineSequencerEnum::GLD_B); + constexpr index_t GST_O = static_cast(FusedMoeGemmPipelineSequencerEnum::GST_O); + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 7 + return seq_all; + // clang-format on + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 && + S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 && + S_::Block_N1 == 128) + { + // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async + // gld_a 8x ds_read_b128 sld_a total 64 slot :) + // clang-format off + constexpr auto seq_all = + // 0 1 2 3 4 5 6 7 + sequence{}; // 3 + return seq_all; + // clang-format on + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1() + { + using S_ = typename Problem::BlockShape; + constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv; + // TODO: ugly + if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16) + { + return WarpGemmImpl, + 2>>{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32) + { + return WarpGemmImpl, + 2>>{}; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + using CDataType = typename WarpGemm::CDataType; + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + using CDataType = typename WarpGemm::CDataType; + + constexpr auto c_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + return c_block_tensor; + } + + // this is used as A matrix for 2nd gemm + template + CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution() + { + using S_ = remove_cvref_t; + using WarpGemm = remove_cvref_t())>; + + // TODO: all waves a along different N, but same M + constexpr auto y_outer_dstr_enc = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{}); + constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode); + return y_block_dstr; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile() + { + constexpr auto y_block_dstr = MakeYTileDistribution(); + auto y_block_tensor = + make_static_distributed_tensor(y_block_dstr); + return y_block_tensor; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetUK_0() + { + using S_ = typename Problem::BlockShape; + if constexpr(std::is_same_v && + std::is_same_v && + S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 && + S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) + { + return Flatmm_ff_32x512x128_1x4x1_16x16x32_BF16{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 && + S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32) + { + return Flatmm_ff_32x512x128_1x4x1_16x16x32_FP16{}; + } + } +}; +} // namespace ck_tile