mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
remove debug 9.8 tflops
This commit is contained in:
19
example/ck_tile/18_flatmm_uk/CMakeLists.txt
Normal file
19
example/ck_tile/18_flatmm_uk/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
# Build script for the flatmm_uk ck_tile example.
set(TILE_EXAPMLE_FLATMM_UK "tile_example_flatmm_uk")
message("adding ${TILE_EXAPMLE_FLATMM_UK}")

# add_example_executable() is deliberately not used for this target: the example
# must stay out of "make all/install/check", hence EXCLUDE_FROM_ALL below.
add_executable(${TILE_EXAPMLE_FLATMM_UK} EXCLUDE_FROM_ALL main.cpp)

# Every kernel instance source under instances/ is compiled into the example.
file(GLOB INSTANCE_SRCS instances/*.cpp)
target_sources(${TILE_EXAPMLE_FLATMM_UK} PRIVATE ${INSTANCE_SRCS})
target_include_directories(${TILE_EXAPMLE_FLATMM_UK} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

# Compile options, in the order they are passed to the compiler:
#  * undefined-func-template is turned off so the sources compile without explicitly
#    declaring function specializations (float-equal is silenced as well)
#  * CK_TILE_BUFFER_LOAD_AGPR=1               TODO: enable load to a
#  * CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4      rta
set(TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS
    -Wno-undefined-func-template
    -Wno-float-equal
    -DCK_TILE_BUFFER_LOAD_AGPR=1
    -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4)
# Optional switches kept for reference (disabled):
# list(APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
# list(APPEND TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)

target_compile_options(${TILE_EXAPMLE_FLATMM_UK} PRIVATE ${TILE_EXAPMLE_FLATMM_UK_COMPILE_OPTIONS})
|
||||
100
example/ck_tile/18_flatmm_uk/flatmm_uk.hpp
Normal file
100
example/ck_tile/18_flatmm_uk/flatmm_uk.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/flatmm_uk.hpp"
|
||||
#include <string>
|
||||
|
||||
// this is only a convenient structure for creating an example
|
||||
// this is not part of the host API
|
||||
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FlatmmUkTypeConfig;
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FlatmmUkTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using GDataType = ck_tile::bf16_t;
|
||||
using DDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FlatmmUkTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::fp16_t;
|
||||
using GDataType = ck_tile::fp16_t;
|
||||
using DDataType = ck_tile::fp16_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::fp16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FlatmmUkTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using GDataType = ck_tile::int8_t;
|
||||
using DDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
|
||||
struct flatmm_uk_args
|
||||
{
|
||||
const void* a_ptr; // [m, k], input token
|
||||
const void* b_ptr; // [m, k], input token
|
||||
const void* c_ptr; // [m, k], output token (no need to do zeroing)
|
||||
void* d_ptr; // [m, k], output token (no need to do zeroing)
|
||||
void* dbg_int_ptr; // [m, k], output token (no need to do zeroing)
|
||||
void* dbg_bf16_ptr; // [m, k], output token (no need to do zeroing)
|
||||
void* dbg_fp32_ptr; // [m, k], output token (no need to do zeroing)
|
||||
|
||||
ck_tile::index_t block_m; // block_m, used to devide the input
|
||||
ck_tile::index_t hidden_size; // k
|
||||
ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
|
||||
ck_tile::index_t num_tokens; // input number of tokens for current iteration
|
||||
ck_tile::index_t num_experts; // number of groups
|
||||
ck_tile::index_t topk; // need this?
|
||||
|
||||
ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
};
|
||||
|
||||
// This is the public API, will be generated by script.
//
// Selects which pre-built kernel instance flatmm_uk() dispatches to. The struct stays an
// aggregate (callers brace-initialize it positionally, so member order matters); the
// integer members carry default initializers so that a default-constructed instance no
// longer holds indeterminate values.
struct flatmm_uk_traits
{
    std::string prec_i;  // input precision
    std::string prec_w;  // weight precision
    std::string prec_o;  // output precision
    std::string prec_st; // token scale data type
    std::string prec_sw; // weight scale data type
    std::string prec_sq; // smooth quant scale
    std::string prec_kw; // topk-weight data type
    int block_m     = 0; // token blocking factor used to pick an instance
    int gate_only   = 0; // non-zero: gate-only w0 layout
    int fused_quant = 0; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
|
||||
|
||||
float flatmm_uk(flatmm_uk_traits, flatmm_uk_args, const ck_tile::stream_config&);
|
||||
192
example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.cpp
Normal file
192
example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.cpp
Normal file
@@ -0,0 +1,192 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "flatmm_uk.hpp"
|
||||
#include "flatmm_uk_api.hpp"
|
||||
#include "ck_tile/ops/flatmm_uk.hpp"
|
||||
#include <iostream>
|
||||
|
||||
template <ck_tile::index_t... Is>
|
||||
using S = ck_tile::sequence<Is...>;
|
||||
|
||||
// do not the define of this tepmlate function inside the _api.cpp, otherwise will block make -j
|
||||
template <typename Ts_>
|
||||
float flatmm_uk_(const ck_tile::stream_config& s_, flatmm_uk_args_ a_)
|
||||
{
|
||||
printf("[FF] ======= fused_moegemm_() ======= \n \tget moe arg in a_ <flatmm_uk_args>, get "
|
||||
"config in Ts_\n");
|
||||
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
|
||||
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0,
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
printf("[FF] --- fused_moegemm_(): <FusedMoeGemmShape> --- \n");
|
||||
printf("[FF] f_shape::BlockSize = %d\n", static_cast<uint32_t>(f_shape::BlockSize));
|
||||
printf("[FF] f_shape::NumWarps = %d\n", static_cast<uint32_t>(f_shape::NumWarps));
|
||||
printf("[FF] --------- \n");
|
||||
printf("[FF] f_shape::Block_M0 = %d\n", static_cast<uint32_t>(f_shape::Block_M0));
|
||||
printf("[FF] f_shape::Block_N0 = %d\n", static_cast<uint32_t>(f_shape::Block_N0));
|
||||
printf("[FF] f_shape::Block_K0 = %d\n", static_cast<uint32_t>(f_shape::Block_K0));
|
||||
printf("[FF] f_shape::WarpPerBlock_M0 = %d\n", static_cast<uint32_t>(f_shape::WarpPerBlock_M0));
|
||||
printf("[FF] f_shape::WarpPerBlock_N0 = %d\n", static_cast<uint32_t>(f_shape::WarpPerBlock_N0));
|
||||
printf("[FF] f_shape::WarpPerBlock_K0 = %d\n", static_cast<uint32_t>(f_shape::WarpPerBlock_K0));
|
||||
printf("[FF] f_shape::Warp_M0 = %d\n", static_cast<uint32_t>(f_shape::Warp_M0));
|
||||
printf("[FF] f_shape::Warp_N0 = %d\n", static_cast<uint32_t>(f_shape::Warp_N0));
|
||||
printf("[FF] f_shape::Warp_K0 = %d\n", static_cast<uint32_t>(f_shape::Warp_K0));
|
||||
printf("[FF] f_shape::ThreadPerBlock_M0 = %d\n",
|
||||
static_cast<uint32_t>(f_shape::ThreadPerBlock_M0));
|
||||
printf("[FF] f_shape::ThreadPerBlock_N0 = %d\n",
|
||||
static_cast<uint32_t>(f_shape::ThreadPerBlock_N0));
|
||||
printf("[FF] f_shape::ThreadPerBlock_K0 = %d\n",
|
||||
static_cast<uint32_t>(f_shape::ThreadPerBlock_K0));
|
||||
printf("[FF] f_shape::Repeat_M0 = %d\n", static_cast<uint32_t>(f_shape::Repeat_M0));
|
||||
printf("[FF] f_shape::Repeat_N0 = %d\n", static_cast<uint32_t>(f_shape::Repeat_N0));
|
||||
printf("[FF] f_shape::Repeat_K0 = %d\n", static_cast<uint32_t>(f_shape::Repeat_K0));
|
||||
printf("[FF] f_shape::Block_W0 = %d\n", static_cast<uint32_t>(f_shape::Block_W0));
|
||||
printf("[FF] f_shape::Block_Nr0 = %d\n", static_cast<uint32_t>(f_shape::Block_Nr0));
|
||||
printf("[FF] f_shape::Block_Kr0 = %d\n", static_cast<uint32_t>(f_shape::Block_Kr0));
|
||||
printf("[FF] --------- \n");
|
||||
printf("[FF] f_shape::Block_M1 = %d\n", static_cast<uint32_t>(f_shape::Block_M1));
|
||||
printf("[FF] f_shape::Block_N1 = %d\n", static_cast<uint32_t>(f_shape::Block_N1));
|
||||
printf("[FF] f_shape::Block_K1 = %d\n", static_cast<uint32_t>(f_shape::Block_K1));
|
||||
printf("[FF] f_shape::WarpPerBlock_M1 = %d\n", static_cast<uint32_t>(f_shape::WarpPerBlock_M1));
|
||||
printf("[FF] f_shape::WarpPerBlock_N1 = %d\n", static_cast<uint32_t>(f_shape::WarpPerBlock_N1));
|
||||
printf("[FF] f_shape::WarpPerBlock_K1 = %d\n", static_cast<uint32_t>(f_shape::WarpPerBlock_K1));
|
||||
printf("[FF] f_shape::Warp_M1 = %d\n", static_cast<uint32_t>(f_shape::Warp_M1));
|
||||
printf("[FF] f_shape::Warp_N1 = %d\n", static_cast<uint32_t>(f_shape::Warp_N1));
|
||||
printf("[FF] f_shape::Warp_K1 = %d\n", static_cast<uint32_t>(f_shape::Warp_K1));
|
||||
printf("[FF] f_shape::ThreadPerBlock_M1 = %d\n",
|
||||
static_cast<uint32_t>(f_shape::ThreadPerBlock_M1));
|
||||
printf("[FF] f_shape::ThreadPerBlock_N1 = %d\n",
|
||||
static_cast<uint32_t>(f_shape::ThreadPerBlock_N1));
|
||||
printf("[FF] f_shape::ThreadPerBlock_K1 = %d\n",
|
||||
static_cast<uint32_t>(f_shape::ThreadPerBlock_K1));
|
||||
printf("[FF] f_shape::Repeat_M1 = %d\n", static_cast<uint32_t>(f_shape::Repeat_M1));
|
||||
printf("[FF] f_shape::Repeat_N1 = %d\n", static_cast<uint32_t>(f_shape::Repeat_N1));
|
||||
printf("[FF] f_shape::Repeat_K1 = %d\n", static_cast<uint32_t>(f_shape::Repeat_K1));
|
||||
printf("[FF] f_shape::Block_W1 = %d\n", static_cast<uint32_t>(f_shape::Block_W1));
|
||||
printf("[FF] f_shape::Block_Nr1 = %d\n", static_cast<uint32_t>(f_shape::Block_Nr1));
|
||||
printf("[FF] f_shape::Block_Kr1 = %d\n", static_cast<uint32_t>(f_shape::Block_Kr1));
|
||||
using f_problem =
|
||||
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
|
||||
typename Ts_::GDataType,
|
||||
typename Ts_::DDataType,
|
||||
typename Ts_::AccDataType,
|
||||
typename Ts_::ODataType,
|
||||
typename Ts_::AScaleDataType,
|
||||
typename Ts_::GScaleDataType,
|
||||
typename Ts_::DScaleDataType,
|
||||
typename Ts_::YSmoothScaleDataType,
|
||||
typename Ts_::TopkWeightDataType,
|
||||
typename Ts_::IndexDataType,
|
||||
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
|
||||
f_shape,
|
||||
f_traits>;
|
||||
|
||||
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
|
||||
using f_pipeline = ck_tile::GemmPipeline_FlatmmUk<f_problem>;
|
||||
using f_kernel = ck_tile::FlatmmUkKernel<f_pipeline, void>;
|
||||
|
||||
const dim3 grids = f_kernel::GridSize(a_);
|
||||
constexpr dim3 blocks = f_kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
printf("[FF] grids = [%d, %d, %d]\n", grids.x, grids.y, grids.z);
|
||||
printf("[FF] blocks = [%d, %d, %d]\n", blocks.x, blocks.y, blocks.z);
|
||||
|
||||
static int printed = 0;
|
||||
|
||||
auto kargs = f_kernel::MakeKargs(a_);
|
||||
f_kernel kernel{};
|
||||
auto lambda_kenrel =
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(kernel, grids, blocks, 0, kargs);
|
||||
|
||||
if(s_.log_level_ > 0 && printed == 10)
|
||||
{
|
||||
// std::cout << ", " << f_kernel::GetName() << std::flush;
|
||||
printed = 1;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s_, lambda_kenrel
|
||||
// ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs)
|
||||
);
|
||||
}
|
||||
|
||||
float flatmm_uk(flatmm_uk_traits t, flatmm_uk_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
// auto s_ = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};
|
||||
auto s_ = s;
|
||||
|
||||
auto t_ = flatmm_uk_traits_{t.prec_i,
|
||||
t.prec_w,
|
||||
t.prec_o,
|
||||
t.prec_st,
|
||||
t.prec_sw,
|
||||
t.prec_sq,
|
||||
t.prec_kw,
|
||||
t.block_m,
|
||||
t.gate_only,
|
||||
t.fused_quant};
|
||||
auto a_ = flatmm_uk_args_{
|
||||
a.a_ptr, // const void* a_ptr;
|
||||
a.b_ptr, // const void* a_ptr;
|
||||
a.c_ptr, // void* o_ptr;
|
||||
a.d_ptr, // void* o_ptr;
|
||||
a.dbg_int_ptr,
|
||||
a.dbg_bf16_ptr,
|
||||
a.dbg_fp32_ptr,
|
||||
a.hidden_size, // index_t hidden_size;
|
||||
a.intermediate_size, // index_t intermediate_size;
|
||||
a.num_tokens, // index_t num_tokens;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
a.stride_token // index_t stride_token;
|
||||
};
|
||||
|
||||
float r = -1;
|
||||
|
||||
if(t_.prec_i == "bf16" && t_.prec_w == "bf16" && t_.prec_o == "bf16" && t_.prec_st == "fp32" &&
|
||||
t_.prec_sw == "fp32" && t_.prec_sq == "fp32" && t_.prec_kw == "fp32" && t_.block_m == 32 &&
|
||||
t_.gate_only == 1)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
ck_tile::bf16_t,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
S<32, 512, 128, 128>,
|
||||
S<1, 4, 1>,
|
||||
S<16, 16, 32>,
|
||||
1,
|
||||
0>;
|
||||
r = flatmm_uk_<t_>(s_, a_);
|
||||
}
|
||||
else if(t_.prec_i == "fp16" && t_.prec_w == "fp16" && t_.prec_o == "fp16" &&
|
||||
t_.prec_st == "fp32" && t_.prec_sw == "fp32" && t_.prec_sq == "fp32" &&
|
||||
t_.prec_kw == "fp32" && t_.block_m == 32 && t_.gate_only == 1)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
ck_tile::fp16_t,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
S<32, 512, 128, 128>,
|
||||
S<1, 4, 1>,
|
||||
S<16, 16, 32>,
|
||||
1,
|
||||
0>;
|
||||
r = flatmm_uk_<t_>(s_, a_);
|
||||
}
|
||||
|
||||
// keep unsupported case return negative
|
||||
if(r < 0)
|
||||
return -1;
|
||||
|
||||
return r;
|
||||
}
|
||||
76
example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.hpp
Normal file
76
example/ck_tile/18_flatmm_uk/instances/flatmm_uk_api.hpp
Normal file
@@ -0,0 +1,76 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/flatmm_uk.hpp"
|
||||
#include <string>
|
||||
|
||||
// runtime args
|
||||
struct flatmm_uk_args_ : public ck_tile::FlatmmUkHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// This is the public API, will be generated by script.
//
// Internal mirror of the example-level traits. The struct stays an aggregate (it is
// brace-initialized positionally, so member order matters); the integer members carry
// default initializers so that a default-constructed instance no longer holds
// indeterminate values.
struct flatmm_uk_traits_
{
    std::string prec_i;  // input precision
    std::string prec_w;  // weight precision
    std::string prec_o;  // output precision
    std::string prec_st; // token scale data type
    std::string prec_sw; // weight scale data type
    std::string prec_sq; // smooth quant scale
    std::string prec_kw; // topk-weight data type
    int block_m     = 0; // token blocking factor used to pick an instance
    int gate_only   = 0; // non-zero: gate-only w0 layout
    int fused_quant = 0; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename I,
|
||||
typename W,
|
||||
typename O,
|
||||
typename ST,
|
||||
typename SW,
|
||||
typename SQ,
|
||||
typename KW,
|
||||
typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
|
||||
typename WarpPerBlock_,
|
||||
typename WarpTile_, // seq<*,*,*>, used to select mfma
|
||||
ck_tile::index_t GateOnly_ = 0,
|
||||
ck_tile::index_t FusedQuant_ = 0>
|
||||
struct fmoe_ // traits, ugly name, only used for internal
|
||||
{
|
||||
using TypeConfig = FlatmmUkTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
|
||||
using ADataType = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
|
||||
using GDataType = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
|
||||
using DDataType = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
|
||||
using AccDataType = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
|
||||
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
|
||||
|
||||
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token
|
||||
static constexpr ck_tile::index_t BI_ =
|
||||
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate
|
||||
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden
|
||||
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down
|
||||
|
||||
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
|
||||
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
|
||||
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr ck_tile::index_t GateOnly = GateOnly_;
|
||||
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
|
||||
};
|
||||
692
example/ck_tile/18_flatmm_uk/main.cpp
Normal file
692
example/ck_tile/18_flatmm_uk/main.cpp
Normal file
@@ -0,0 +1,692 @@
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "flatmm_uk.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void my_reference_gemm(const ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
const ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n,
|
||||
float t,
|
||||
const AElementOp& a_element_op = {},
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(0);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
printf("[REF] M = %zu, N = %zu, K = %zu\n", M, N, K);
|
||||
|
||||
auto cal_tflops = [&](auto ms) {
|
||||
double flop_gemm = 2.0 * M * N * K;
|
||||
return (flop_gemm) / (static_cast<double>(ms) * 1e-3) / 1e12;
|
||||
};
|
||||
|
||||
auto cal_tbps = [&](auto ms) {
|
||||
double a_bytes = static_cast<double>(M) * K * sizeof(ADataType);
|
||||
double b_bytes = static_cast<double>(N) * K * sizeof(BDataType);
|
||||
double o_bytes = static_cast<double>(M) * N * sizeof(CDataType);
|
||||
|
||||
return (a_bytes + b_bytes + o_bytes) / (static_cast<double>(ms) * 1e-3) / 1e12;
|
||||
};
|
||||
|
||||
std::cout << ", " << t * 1.E3 << " us, " << cal_tflops(t) << " tflops, " << cal_tbps(t)
|
||||
<< " TB/s" << std::endl
|
||||
<< std::flush;
|
||||
|
||||
auto f_mn = [&](auto m, auto n) {
|
||||
AccDataType v_acc = 0;
|
||||
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
ADataType v_a = a_element_op(a_m_k(m, k));
|
||||
BDataType v_b = b_element_op(b_k_n(n, k));
|
||||
|
||||
v_acc +=
|
||||
ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
};
|
||||
|
||||
ck_tile::make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
// mfma_type, 0:32x32, 1:16x16
|
||||
// TODO: padding?
|
||||
template <typename T>
|
||||
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
|
||||
{
|
||||
assert(t.get_lengths().size() == 3);
|
||||
int b_ = t.get_lengths()[0];
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[2];
|
||||
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 32, 4, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
return t;
|
||||
}
|
||||
template <typename T>
|
||||
auto shuffle_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[0];
|
||||
int k_ = t.get_lengths()[1];
|
||||
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({n_ / 32, 32, k_ / 16, 2, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({n_ / 16, 16, k_ / 32, 4, 8});
|
||||
printf("[FF] permute: n_ = %d, k_ = %d, n_/16 = %d, k_/32 = %d\n", n_, k_, n_ / 16, k_ / 32);
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({n_ / 32, 32, k_ / 32, 2, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({n_ / 16, 16, k_ / 64, 4, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "64", "num of m")
|
||||
.insert("n", "1024", "num of n")
|
||||
.insert("k", "8192", "num of k")
|
||||
.insert("t", "64", "num input tokens")
|
||||
.insert("e", "8", "num of experts")
|
||||
.insert("tk", "1", "topk")
|
||||
.insert("h", "4096", "hidden_size of this model")
|
||||
.insert("i", "4096", "intermediate_size between 2 gemms of FFN")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
|
||||
.insert("bm", "32", "blocking factor for sorted tokens")
|
||||
.insert("tp", "8", "tensor parallel size")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec_i", "bf16", "input precision")
|
||||
.insert("prec_w", "bf16", "weight precision")
|
||||
.insert("prec_o", "bf16", "output precision")
|
||||
.insert("prec_st", "auto", "token scale data type. auto will set to fp32")
|
||||
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
|
||||
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
|
||||
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert(
|
||||
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
|
||||
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
|
||||
.insert("balance",
|
||||
"0",
|
||||
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
|
||||
.insert("init",
|
||||
"2",
|
||||
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
|
||||
"normalized(slow)")
|
||||
.insert("seed", "11939", "seed used to do random")
|
||||
.insert("warmup", "1", "cold iter")
|
||||
.insert("repeat", "4", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// I:input-type, W:weight-type, O:output-type, ST:token-scale-type, SW:weight-scale-type,
|
||||
// SQ:smooth-quant-type, KW:topk-weight-type
|
||||
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
printf("[FF] M = %d, N = %d, K = %d\n", M, N, K);
|
||||
|
||||
ck_tile::index_t experts = arg_parser.get_int("e");
|
||||
ck_tile::index_t topk = arg_parser.get_int("tk");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
ck_tile::index_t block_m = arg_parser.get_int("bm");
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_w = arg_parser.get_str("prec_w");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_st = arg_parser.get_str("prec_st");
|
||||
std::string prec_sw = arg_parser.get_str("prec_sw");
|
||||
std::string prec_sq = arg_parser.get_str("prec_sq");
|
||||
std::string prec_kw = arg_parser.get_str("prec_kw");
|
||||
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
|
||||
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
|
||||
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
|
||||
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int gate_only = arg_parser.get_int("gate_only");
|
||||
int init = arg_parser.get_int("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
|
||||
using TypeConfig = FlatmmUkTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = ADataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using CDataType = AccDataType;
|
||||
using DDataType = AccDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host({M, K});
|
||||
ck_tile::HostTensor<BDataType> b_host({N, K});
|
||||
ck_tile::HostTensor<CDataType> c_host({M, N});
|
||||
ck_tile::HostTensor<DDataType> d_host({M, N});
|
||||
|
||||
ck_tile::HostTensor<int> dbg_int({M * N, K});
|
||||
ck_tile::HostTensor<float> dbg_fp32({M * N, K});
|
||||
ck_tile::HostTensor<ck_tile::bf16_t> dbg_bf16({M * N, K});
|
||||
|
||||
if(init == 0)
|
||||
{
|
||||
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
|
||||
ck_tile::FillStepRange<BDataType>{-.5f, .5f, 0.01f}(b_host);
|
||||
}
|
||||
else if(init == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f, seed, true}(b_host);
|
||||
}
|
||||
else if(init == 2)
|
||||
{
|
||||
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed, true}(a_host);
|
||||
ck_tile::FillNormalDistribution<BDataType>{0.f, 1.f, seed, true}(b_host);
|
||||
}
|
||||
/*
|
||||
// a_host
|
||||
{
|
||||
int X = static_cast<int>(K);
|
||||
int Y = static_cast<int>(M);
|
||||
|
||||
for(int y = 0; y < Y; y++)
|
||||
{
|
||||
for(int x = 0; x < X; x++)
|
||||
{
|
||||
int idx = X * y + x;
|
||||
a_host.mData[idx] = ck_tile::type_convert<ADataType>(x * 1.0f);
|
||||
//b_host.mData[idx] = ck_tile::type_convert<GDataType>(y * 1.0f);
|
||||
//b_host.mData[idx] = ck_tile::type_convert<GDataType>(y*1.f + x * 0.0001f);
|
||||
}
|
||||
}
|
||||
}
|
||||
// b_host
|
||||
{
|
||||
int X = static_cast<int>(K);
|
||||
int Y = static_cast<int>(N);
|
||||
|
||||
for(int y = 0; y < Y; y++)
|
||||
{
|
||||
for(int x = 0; x < X; x++)
|
||||
{
|
||||
int idx = X * y + x;
|
||||
b_host.mData[idx] = ck_tile::type_convert<GDataType>(idx * 1.0f);
|
||||
//b_host.mData[idx] = ck_tile::type_convert<GDataType>(y * 1.0f);
|
||||
//b_host.mData[idx] = ck_tile::type_convert<GDataType>(y*1.f + x * 0.0001f);
|
||||
}
|
||||
}
|
||||
}*/
|
||||
|
||||
// permute weight
|
||||
ck_tile::HostTensor<BDataType> b_perm_host = shuffle_weight(b_host, prec_w, 1);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host);
|
||||
ck_tile::DeviceMem b_buf(b_perm_host); // b_host -> b_perm_host
|
||||
ck_tile::DeviceMem c_buf(c_host);
|
||||
ck_tile::DeviceMem d_buf(d_host);
|
||||
ck_tile::DeviceMem dbg_int_buf(dbg_int);
|
||||
ck_tile::DeviceMem dbg_bf16_buf(dbg_bf16);
|
||||
ck_tile::DeviceMem dbg_fp32_buf(dbg_fp32);
|
||||
|
||||
flatmm_uk_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
prec_st,
|
||||
prec_sw,
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
printf("[FF] --- run(): <flatmm_uk_traits> ---\n");
|
||||
printf("[FF] traits.prec_i = %s\n", traits.prec_i.c_str());
|
||||
printf("[FF] traits.prec_w = %s\n", traits.prec_w.c_str());
|
||||
printf("[FF] traits.prec_o = %s\n", traits.prec_o.c_str());
|
||||
printf("[FF] traits.prec_st = %s\n", traits.prec_st.c_str());
|
||||
printf("[FF] traits.prec_sw = %s\n", traits.prec_sw.c_str());
|
||||
printf("[FF] traits.prec_sq = %s\n", traits.prec_sq.c_str());
|
||||
printf("[FF] traits.prec_kw = %s\n", traits.prec_kw.c_str());
|
||||
printf("[FF] traits.block_m = %d\n", traits.block_m);
|
||||
printf("[FF] traits.gate_only = %d\n", traits.gate_only);
|
||||
printf("[FF] traits.fused_quant = %d\n", traits.fused_quant);
|
||||
|
||||
flatmm_uk_args args{a_buf.GetDeviceBuffer(),
|
||||
b_buf.GetDeviceBuffer(),
|
||||
c_buf.GetDeviceBuffer(),
|
||||
d_buf.GetDeviceBuffer(),
|
||||
dbg_int_buf.GetDeviceBuffer(),
|
||||
dbg_bf16_buf.GetDeviceBuffer(),
|
||||
dbg_fp32_buf.GetDeviceBuffer(),
|
||||
block_m,
|
||||
K,
|
||||
N,
|
||||
M,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
printf("[FF] --- run(): <flatmm_uk_args> ---\n");
|
||||
printf("[FF] args.block_m = %d\n", args.block_m);
|
||||
printf("[FF] args.hidden_size = %d\n", args.hidden_size);
|
||||
printf("[FF] args.intermediate_size = %d\n", args.intermediate_size);
|
||||
printf("[FF] args.num_tokens = %d\n", args.num_tokens); // 1
|
||||
printf("[FF] args.topk = %d\n", args.topk); // 0
|
||||
printf("[FF] args.num_experts = %d\n", args.num_experts); // 0
|
||||
printf("[FF] args.stride_token = %d\n", args.stride_token);
|
||||
|
||||
float ave_time = flatmm_uk(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << " not supported!" << std::endl << std::flush;
|
||||
return false;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
auto d_dev = d_buf.ToHost<float>();
|
||||
std::cout << std::endl << " =================== " << std::endl;
|
||||
d_host.SetZero();
|
||||
my_reference_gemm<ADataType, BDataType, CDataType, DDataType>(
|
||||
a_host, b_host, d_host, ave_time);
|
||||
pass = ck_tile::check_err(d_dev, d_host);
|
||||
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
#if 0
|
||||
int GridDimX = 2;
|
||||
int GridDimY = 1;
|
||||
int BlockDimX = 64;
|
||||
int BlockDimY = 4;
|
||||
int BlockSize = BlockDimX * BlockDimY;
|
||||
// dbg_int
|
||||
{
|
||||
auto dbg_int_dev = dbg_int_buf.ToHost<int>();
|
||||
std::ofstream file("ff_dbg_int.txt");
|
||||
file << " [dbg_int]: Grid = [" << GridDimX << ", " << GridDimY << "], Block = " << BlockSize
|
||||
<< std::endl;
|
||||
|
||||
for(int bidy = 0; bidy < GridDimY; bidy++)
|
||||
{
|
||||
for(int bidx = 0; bidx < GridDimX; bidx++)
|
||||
{
|
||||
file << "\n ========== block : [" << bidx << ", " << bidy << "] ==========";
|
||||
for(int tid = 0; tid < BlockSize; tid++)
|
||||
{
|
||||
int gid = (BlockSize * GridDimX) * bidy + BlockSize * bidx + tid;
|
||||
if(tid % 64 == 0)
|
||||
{
|
||||
file << "\n [" << tid << " : " << tid + 63 << "]: ";
|
||||
}
|
||||
file << ck_tile::type_convert<int>(dbg_int_dev.mData[gid]) << ", ";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
// dbg_bf16 ---> kernel
|
||||
{
|
||||
auto dbg_bf16_dev = dbg_bf16_buf.ToHost<BDataType>();
|
||||
std::ofstream file("ff_dbg_bf16_kernel.txt");
|
||||
file << " [dbg_bf16]: Grid = [" << GridDimX << ", " << GridDimY
|
||||
<< "], Block = " << BlockSize << std::endl;
|
||||
|
||||
for(int bidy = 0; bidy < GridDimY; bidy++)
|
||||
{
|
||||
for(int bidx = 0; bidx < GridDimX; bidx++)
|
||||
{
|
||||
file << "\n ========== block : [" << bidx << ", " << bidy << "] ==========";
|
||||
for(int tid = 0; tid < BlockSize; tid++)
|
||||
{
|
||||
int gid = (BlockSize * bidx) * bidy + BlockSize * bidx + tid;
|
||||
|
||||
file << "\n [" << tid << "]: ";
|
||||
for(int i = 0; i < 64; i++) // multi output per thread
|
||||
file << ck_tile::type_convert<float>(dbg_bf16_dev.mData[gid * 64 + i])
|
||||
<< ", ";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
// dbg_bf16
|
||||
{
|
||||
auto dbg_bf16_dev = dbg_bf16_buf.ToHost<BDataType>();
|
||||
std::ofstream file("ff_dbg_bf16.txt");
|
||||
int X = static_cast<int>(N);
|
||||
int Y = static_cast<int>(M);
|
||||
file << " [dbg_bf16]: Row = " << Y << ", Col = " << X << std::endl;
|
||||
|
||||
for(int m = 0; m < Y; m++)
|
||||
{
|
||||
file << "\n ========== row : [" << m << " / " << Y << "] ==========";
|
||||
for(int n = 0; n < X; n++)
|
||||
{
|
||||
if(n % 64 == 0)
|
||||
{
|
||||
file << "\n [" << n << " : " << n + 63 << "]: ";
|
||||
}
|
||||
int idx = X * m + n;
|
||||
file << ck_tile::type_convert<float>(dbg_bf16_dev.mData[idx]) << ", ";
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
// dbg_fp32 ---> kernel
|
||||
{
|
||||
auto dbg_fp32_dev = dbg_fp32_buf.ToHost<float>();
|
||||
std::ofstream file("ff_dbg_fp32_kernel.txt");
|
||||
file << " [dbg_fp32]: Grid = [" << GridDimX << ", " << GridDimY
|
||||
<< "], Block = " << BlockSize << std::endl;
|
||||
|
||||
for(int bidy = 0; bidy < GridDimY; bidy++)
|
||||
{
|
||||
for(int bidx = 0; bidx < GridDimX; bidx++)
|
||||
{
|
||||
file << "\n ========== block : [" << bidx << ", " << bidy << "] ==========";
|
||||
for(int tid = 0; tid < BlockSize; tid++)
|
||||
{
|
||||
int gid = (BlockSize * bidx) * bidy + BlockSize * bidx + tid;
|
||||
|
||||
file << "\n [" << tid << "]: ";
|
||||
for(int i = 0; i < 64; i++) // multi output per thread
|
||||
file << ck_tile::type_convert<float>(dbg_fp32_dev.mData[gid * 64 + i])
|
||||
<< ", ";
|
||||
|
||||
// if(tid % 64 == 0) // one output per thread
|
||||
// file << "\n [" << tid << " : " << tid + 63 << "]: ";
|
||||
// file << ck_tile::type_convert<float>(dbg_bf16.mData[gid]) << ", ";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
// dbg_fp32
|
||||
{
|
||||
auto dbg_fp32_dev = dbg_fp32_buf.ToHost<float>();
|
||||
std::ofstream file("ff_dbg_fp32.txt");
|
||||
int X = static_cast<int>(N);
|
||||
int Y = static_cast<int>(M);
|
||||
file << " [dbg_fp32]: Row = " << Y << ", Col = " << X << std::endl;
|
||||
|
||||
for(int m = 0; m < Y; m++)
|
||||
{
|
||||
file << "\n ========== row : [" << m << " / " << Y << "] ==========";
|
||||
for(int n = 0; n < X; n++)
|
||||
{
|
||||
if(n % 64 == 0)
|
||||
{
|
||||
file << "\n [" << n << " : " << n + 63 << "]: ";
|
||||
}
|
||||
int idx = X * m + n;
|
||||
file << ck_tile::type_convert<float>(dbg_fp32_dev.mData[idx]) << ", ";
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
// a_host
|
||||
{
|
||||
std::ofstream file("ff_a_host.txt");
|
||||
int X = static_cast<int>(K);
|
||||
int Y = static_cast<int>(M);
|
||||
file << " [a_host]: Row = " << Y << ", Col = " << X << std::endl;
|
||||
|
||||
for(int y = 0; y < Y; y++)
|
||||
{
|
||||
file << "\n ========== row : [" << y << " / " << Y << "] ==========";
|
||||
for(int x = 0; x < X; x++)
|
||||
{
|
||||
int idx = X * y + x;
|
||||
if(idx % 16 == 0)
|
||||
{
|
||||
file << "\n [" << x << " : " << x + 15 << " ]: ";
|
||||
}
|
||||
|
||||
file << ck_tile::type_convert<float>(a_host.mData[idx]) << ", ";
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
// b_host
|
||||
{
|
||||
std::ofstream file("ff_b_host.txt");
|
||||
int X = static_cast<int>(K);
|
||||
int Y = static_cast<int>(N);
|
||||
file << " [b_host]: Row = " << Y << ", Col = " << X << std::endl;
|
||||
|
||||
for(int y = 0; y < Y; y++)
|
||||
{
|
||||
file << "\n ========== row : [" << y << " / " << Y << "] ==========";
|
||||
for(int x = 0; x < X; x++)
|
||||
{
|
||||
int idx = X * y + x;
|
||||
if(idx % 16 == 0)
|
||||
{
|
||||
file << "\n [" << x << " : " << x + 15 << " ]: ";
|
||||
}
|
||||
|
||||
file << ck_tile::type_convert<float>(b_host.mData[idx]) << ", ";
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
// permute_b
|
||||
{
|
||||
std::ofstream file("ff_b_perm_host.txt");
|
||||
int X = static_cast<int>(K);
|
||||
int Y = static_cast<int>(N);
|
||||
file << " [b_perm_host]: Row = " << Y << ", Col = " << X << std::endl;
|
||||
|
||||
for(int y = 0; y < Y; y++)
|
||||
{
|
||||
file << "\n ========== row : [" << y << " / " << Y << "] ==========";
|
||||
for(int x = 0; x < X; x++)
|
||||
{
|
||||
int idx = X * y + x;
|
||||
if(idx % 16 == 0)
|
||||
{
|
||||
file << "\n [" << x << " : " << x + 15 << " ]: ";
|
||||
}
|
||||
|
||||
file << ck_tile::type_convert<float>(b_perm_host.mData[idx]) << ", ";
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
// d_dev ---> kernel
|
||||
{
|
||||
auto d_dev = d_buf.ToHost<float>();
|
||||
std::ofstream file("ff_d_dev_kernel.txt");
|
||||
file << " [d_dev]: Grid = [" << GridDimX << ", " << GridDimY << "], Block = " << BlockSize
|
||||
<< std::endl;
|
||||
|
||||
for(int bidy = 0; bidy < GridDimY; bidy++)
|
||||
{
|
||||
for(int bidx = 0; bidx < GridDimX; bidx++)
|
||||
{
|
||||
file << "\n ========== block : [" << bidx << ", " << bidy << "] ==========";
|
||||
for(int tid = 0; tid < BlockSize; tid++)
|
||||
{
|
||||
int gid = (BlockSize * bidx) * bidy + BlockSize * bidx + tid;
|
||||
|
||||
file << "\n [" << tid << "]: ";
|
||||
for(int i = 0; i < 64; i++) // multi output per thread
|
||||
file << ck_tile::type_convert<float>(d_dev.mData[gid * 64 + i]) << ", ";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
// d_dev
|
||||
{
|
||||
auto d_dev = d_buf.ToHost<float>();
|
||||
std::ofstream file("ff_d_dev.txt");
|
||||
int X = static_cast<int>(N);
|
||||
int Y = static_cast<int>(M);
|
||||
file << " [d_dev]: Row = " << Y << ", Col = " << X << std::endl;
|
||||
|
||||
for(int y = 0; y < Y; y++)
|
||||
{
|
||||
file << "\n ========== row : [" << y << " / " << Y << "] ==========";
|
||||
for(int x = 0; x < X; x++)
|
||||
{
|
||||
if(x % 64 == 0)
|
||||
{
|
||||
file << "\n [" << x << " : " << x + 63 << "]: ";
|
||||
}
|
||||
int idx = X * y + x;
|
||||
file << ck_tile::type_convert<float>(d_dev.mData[idx]) << ", ";
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
// d_host
|
||||
{
|
||||
std::ofstream file("ff_d_host.txt");
|
||||
int X = static_cast<int>(N);
|
||||
int Y = static_cast<int>(M);
|
||||
file << " [d_host]: Row = " << Y << ", Col = " << X << std::endl;
|
||||
|
||||
for(int y = 0; y < Y; y++)
|
||||
{
|
||||
file << "\n ========== row : [" << y << " / " << Y << "] ==========";
|
||||
for(int x = 0; x < X; x++)
|
||||
{
|
||||
if(x % 64 == 0)
|
||||
{
|
||||
file << "\n [" << x << " : " << x + 63 << "]: ";
|
||||
}
|
||||
int idx = X * y + x;
|
||||
file << ck_tile::type_convert<float>(d_host.mData[idx]) << ", ";
|
||||
}
|
||||
}
|
||||
|
||||
file.close();
|
||||
}
|
||||
#endif
|
||||
|
||||
std::cout << std::flush << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_w = arg_parser.get_str("prec_w");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_st = arg_parser.get_str("prec_st");
|
||||
std::string prec_sw = arg_parser.get_str("prec_sw");
|
||||
std::string prec_sq = arg_parser.get_str("prec_sq");
|
||||
std::string prec_kw = arg_parser.get_str("prec_kw");
|
||||
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
|
||||
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
|
||||
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
|
||||
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
|
||||
|
||||
// no dynamic quant case
|
||||
if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
|
||||
arg_parser)
|
||||
? 0
|
||||
: -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32")
|
||||
{
|
||||
return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>(
|
||||
arg_parser)
|
||||
? 0
|
||||
: -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
@@ -17,3 +17,4 @@ add_subdirectory(14_moe_smoothquant)
|
||||
add_subdirectory(15_fused_moe)
|
||||
add_subdirectory(16_batched_gemm)
|
||||
add_subdirectory(17_grouped_gemm)
|
||||
add_subdirectory(18_flatmm_uk)
|
||||
|
||||
@@ -0,0 +1,665 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A async load to LDS, B direct to AGPR
|
||||
// B matrix preshuffled in br*kr*w
|
||||
// require 4 wave, occupancy=1c
|
||||
// agpr usage:256
|
||||
// vgpr usage:64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112)
|
||||
//
|
||||
// for this gemm, 4 16x16x16 transposed layout
|
||||
// input A vgpr layout
|
||||
// v0-v15: [ 0:15](gemm_m)x128(gemm_k)
|
||||
// v16-v31: [16:31](gemm_m)x128(gemm_k)
|
||||
|
||||
// input B vgpr layout
|
||||
// v0-v15: [ 0: 15](gemm_n)x128(gemm_k)
|
||||
// v16-v31: [ 64: 79](gemm_n)x128(gemm_k)
|
||||
// ......................
|
||||
// v112-v127: [448:463](gemm_n)x128(gemm_k)
|
||||
|
||||
// output C vgpr layout
|
||||
// v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n)
|
||||
// v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n)
|
||||
// v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n)
|
||||
// v12-v15: [16:31](gemm_m)x[64: 79](gemm_n)
|
||||
// ......................
|
||||
// v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n)
|
||||
// v60-v63: [16:31](gemm_m)x[448:463](gemm_n)
|
||||
struct Flatmm_ff_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
{
|
||||
static constexpr index_t Block_M = 32;
|
||||
static constexpr index_t Block_N = 512;
|
||||
static constexpr index_t Block_K = 128;
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = 1;
|
||||
static constexpr index_t WarpPerBlock_N = 4;
|
||||
static constexpr index_t WarpPerBlock_K = 1;
|
||||
|
||||
static constexpr index_t NumWarps = 4;
|
||||
|
||||
static constexpr index_t Warp_M = 16;
|
||||
static constexpr index_t Warp_N = 16;
|
||||
static constexpr index_t Warp_K = 32; // 16 * SubKPacks
|
||||
|
||||
static constexpr index_t BlockSize = 256;
|
||||
|
||||
static constexpr index_t SubKPacks = 2; // this is used to gurantee every threads can do dwordx4
|
||||
|
||||
// TODO: note Nr/Kr/W need consider SubKPacks
|
||||
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
|
||||
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
|
||||
static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
|
||||
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
|
||||
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
|
||||
|
||||
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<2, 1>, // !! note here is different
|
||||
sequence<0, 0>>{};
|
||||
|
||||
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
return c_block_dstr;
|
||||
}
|
||||
|
||||
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
|
||||
{
|
||||
using CDataType = float;
|
||||
constexpr auto c_block_dstr = MakeCBlockDist();
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
|
||||
{
|
||||
// A async->LDS
|
||||
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
|
||||
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
|
||||
constexpr index_t KPad = KPack_; // pad between warps
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
if constexpr(wavesPerK > NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t wavesPerM = NumWarps / wavesPerK;
|
||||
constexpr index_t NumIssues = Block_M / wavesPerM;
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NumIssues>{}),
|
||||
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
|
||||
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
return lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<LaneGroups>{}, // m1
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<NumIssues>{}),
|
||||
make_pass_through_transform(number<NumWarps>{}),
|
||||
make_merge_transform(make_tuple(
|
||||
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
return lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
}
|
||||
|
||||
// template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
|
||||
{
|
||||
// load from LDS to register, every wave has same layout
|
||||
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
|
||||
constexpr index_t KPad = KPack_; // pad between warps
|
||||
|
||||
constexpr index_t kAMLane = 16;
|
||||
constexpr index_t kABKLane = 4;
|
||||
constexpr index_t kABKPerLane = 4;
|
||||
constexpr index_t kKIter = 2;
|
||||
static_assert(KPack_ == (kABKPerLane * kKIter));
|
||||
|
||||
constexpr auto lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<Repeat_M>{}, // m0 y
|
||||
number<kAMLane>{}, // m1 p
|
||||
number<Repeat_K>{}, // k0 y
|
||||
number<kABKLane>{}, // k1 p
|
||||
number<KPack_>{}), // k2 y-vector
|
||||
make_tuple(number<kAMLane*(Block_K + KPad)>{}, // m0
|
||||
number<Block_K + KPad>{}, // m1
|
||||
number<kABKLane * KPack_>{}, // k0
|
||||
number<KPack_>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KPack_>{}, // lds load vector
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
|
||||
make_merge_transform(
|
||||
make_tuple(number<Repeat_K>{}, number<kABKLane>{}, number<KPack_>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_desc_m_k;
|
||||
}
|
||||
|
||||
static constexpr auto GetGemm_AWarpEnc()
|
||||
{
|
||||
constexpr index_t kAMLane = 16;
|
||||
constexpr index_t kABKLane = 4;
|
||||
constexpr index_t kABKPerLane = 4;
|
||||
constexpr index_t kKIter = 2;
|
||||
|
||||
using enc_ = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane * kKIter>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<1>>;
|
||||
return enc_{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return 32 * (128 + 8) * sizeof(bf16_t);
|
||||
}
|
||||
};
|
||||
|
||||
/// bf16 micro-kernel wrapper: sets up the LDS windows for A, then runs the
/// hand-written inline-asm main loop and returns the accumulators as a C block
/// tile.
struct Flatmm_ff_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_ff_32x512x128_1x4x1_16x16x32_Base
{
    using ADataType = bf16_t;
    using BDataType = bf16_t;

    // TODO: need paired with tile_window_linear!
    // TODO: need call init_raw() before call this function!
    /// Runs the micro-kernel over the K dimension.
    ///
    /// @param res_a            indexable with [0..3]; the four values are passed
    ///                         to the asm as scalar operands s_res_a0..3
    /// @param cached_coords_a  8 per-thread A offsets in elements (checked by
    ///                         static_assert); converted to bytes for the asm
    /// @param res_b            as res_a, for B
    /// @param cached_coords_b  Repeat_N per-thread B offsets in elements
    /// @param smem             LDS base pointer; bound to the asm as a
    ///                         read-write operand
    /// @param k                K extent; the asm loop count is k / Block_K
    /// @param tile_offset_a    for each tile, the offset (in elements) to move
    ///                         for each unroll
    /// @param tile_offset_b    same, for B
    /// @param dbg_int, dbg_bf16, dbg_fp32  unused debug outputs
    /// @return the C block tile (see MakeCBlockTile) filled from the asm
    ///         accumulators
    template <typename ARes, typename ACoords, typename BRes, typename BCoords>
    CK_TILE_DEVICE auto
    operator()(const ARes& res_a,
               const ACoords& cached_coords_a,
               const BRes& res_b,
               const BCoords& cached_coords_b,
               CK_TILE_LDS_ADDR void* smem,
               index_t k,
               index_t tile_offset_a, // for each tile, the offset to move for each unroll
               index_t tile_offset_b, // for each tile, the offset to move for each unroll
               int* dbg_int,
               short* dbg_bf16,
               float* dbg_fp32)
    {
        // The debug buffers are part of the interface but not used here.
        (void)dbg_int;
        (void)dbg_bf16;
        (void)dbg_fp32;

        static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
        static_assert(BCoords::size() == Repeat_N);

        // Number of fp32x4 accumulators; fixed by the v_acc_0..v_acc_15 operands
        // of the asm statement below.
        constexpr index_t kNumAccVec = 16;
        static_assert(kNumAccVec == Repeat_M * Repeat_N);

        // LDS window described by the store-side descriptor.
        auto a_sst = make_tile_window(
            make_tensor_view<address_space_enum::lds>(
                reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
            MakeLdsStoreDesc_A().get_lengths(),
            {0, 0, 0});

        // Linear LDS window used to derive the per-access LDS load offsets.
        auto a_sld = [&]() {
            constexpr auto a_warp_enc_      = GetGemm_AWarpEnc();
            constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
                sequence<WarpPerBlock_N>,
                tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
                tuple<sequence<1, 0>>,
                tuple<sequence<1, 0>>,
                sequence<1, 2>,
                sequence<0, 0>>{};
            constexpr auto a_block_dstr_encode =
                detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(
                    reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
                MakeLdsLoadDesc_A().get_lengths(),
                {0, 0},
                make_static_tile_distribution(a_block_dstr_encode));
        }();

        // The asm consumes byte offsets.
        const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
        const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);

        const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
        constexpr auto smem_buf_size =
            MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
        static_assert(a_sld.get_num_of_access() == 8);
        // Compile-time byte offsets of each LDS load access (immediates sld_os_0..7).
        constexpr auto sld_os = generate_tuple(
            [&](auto i_access) {
                return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
            },
            number<a_sld.get_num_of_access()>{});

        index_t loop_cnt = k / Block_K;

        // this is the acc thread buffer
        fp32x4_t v_acc[kNumAccVec]{.0f};

#pragma region ASM
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
        // clang-format off
        asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_ff_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
            : [s_loop_cnt]"+s"(loop_cnt),
              [v_acc_0] "+v"(v_acc[0]),
              [v_acc_1] "+v"(v_acc[1]),
              [v_acc_2] "+v"(v_acc[2]),
              [v_acc_3] "+v"(v_acc[3]),
              [v_acc_4] "+v"(v_acc[4]),
              [v_acc_5] "+v"(v_acc[5]),
              [v_acc_6] "+v"(v_acc[6]),
              [v_acc_7] "+v"(v_acc[7]),
              [v_acc_8] "+v"(v_acc[8]),
              [v_acc_9] "+v"(v_acc[9]),
              [v_acc_10]"+v"(v_acc[10]),
              [v_acc_11]"+v"(v_acc[11]),
              [v_acc_12]"+v"(v_acc[12]),
              [v_acc_13]"+v"(v_acc[13]),
              [v_acc_14]"+v"(v_acc[14]),
              [v_acc_15]"+v"(v_acc[15]),
              [s_mem_]"+r"(smem)
            : [s_res_a0]"s"(res_a[0]),
              [s_res_a1]"s"(res_a[1]),
              [s_res_a2]"s"(res_a[2]),
              [s_res_a3]"s"(res_a[3]),
              [v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
              [v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
              [v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
              [v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
              [v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
              [v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
              [v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
              [v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),

              [s_res_b0]"s"(res_b[0]),
              [s_res_b1]"s"(res_b[1]),
              [s_res_b2]"s"(res_b[2]),
              [s_res_b3]"s"(res_b[3]),
              [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
              [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
              [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
              [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
              [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
              [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
              [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
              [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),

              [v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
              [s_m0_init]"s"(m0_init_value),
              [s_size_per_issue]"s"(size_per_issue),
              [smem_sz]"n"(smem_buf_size),
              [sld_os_0]"n"(sld_os[number<0>{}].value),
              [sld_os_1]"n"(sld_os[number<1>{}].value),
              [sld_os_2]"n"(sld_os[number<2>{}].value),
              [sld_os_3]"n"(sld_os[number<3>{}].value),
              [sld_os_4]"n"(sld_os[number<4>{}].value),
              [sld_os_5]"n"(sld_os[number<5>{}].value),
              [sld_os_6]"n"(sld_os[number<6>{}].value),
              [sld_os_7]"n"(sld_os[number<7>{}].value),
              [s_tile_os_a]"s"(tile_offset_a_bytes),
              [s_tile_os_b]"s"(tile_offset_b_bytes)
            : "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
              "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
              "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
              "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
              "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
              "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
              "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
              "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
              "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
              "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
              "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
              "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
              "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
              "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
              "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
              "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
              "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
              "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
              "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
              "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
              "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
              "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
              "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
              "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
              "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
              "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
              "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
              "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
              "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
              "a252", "a253", "a254", "a255",
              "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
              "s86", // s86 as tmp
              "v64", "v65", "v66", "v67", "v68", "v69",
              "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
              "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
              "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
              "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
              "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
              "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
              "v124", "v125", "v126", "v127"
        );
        // clang-format on
#pragma clang diagnostic pop
#pragma endregion

        // return local scratch: unpack each fp32x4 accumulator into the C tile's
        // thread buffer, four consecutive floats per accumulator.
        auto c = MakeCBlockTile();
        for(index_t i = 0; i < kNumAccVec; i++)
        {
            c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
            c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
            c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
            c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
        }

        return c;
    }
};
|
||||
|
||||
// FP16 variant of the flatmm micro-kernel wrapper: it derives from
// Flatmm_ff_32x512x128_1x4x1_16x16x32_Base, fixes both operand element types to
// fp16_t, and runs the shared inline-asm micro-kernel with the FP16 MFMA selected.
struct Flatmm_ff_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_ff_32x512x128_1x4x1_16x16x32_Base
|
||||
{
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
|
||||
// Runs the hand-written micro-kernel and returns the accumulated C block tile.
//
// res_a / res_b         : indexed [0..3] and passed as four scalar operands each
//                         (presumably the buffer resource descriptors of A and B -
//                         TODO confirm against the caller).
// cached_coords_a / _b  : per-access offsets; each one is scaled by the element
//                         size and handed to the asm as a vector operand.
// smem                  : LDS scratch used for the A tile (store and load views
//                         are both created over this pointer).
// k                     : reduction length; the asm loop counter is k / Block_K.
// tile_offset_a / _b    : per-iteration advance, scaled by the element size
//                         before being passed to the asm.
// dbg_int / dbg_bf16 / dbg_fp32 : unused here (explicitly cast to void).
// TODO: need paired with tile_window_linear!
|
||||
// TODO: need call init_raw() before call this function!
|
||||
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(const ARes& res_a,
|
||||
const ACoords& cached_coords_a,
|
||||
const BRes& res_b,
|
||||
const BCoords& cached_coords_b,
|
||||
CK_TILE_LDS_ADDR void* smem,
|
||||
index_t k,
|
||||
index_t tile_offset_a, // for each tile, the offset to move for each unroll
|
||||
index_t tile_offset_b,
|
||||
int * dbg_int,
|
||||
short* dbg_bf16,
|
||||
float* dbg_fp32)
|
||||
{
|
||||
// The debug pointers are kept in the signature but are not used by this variant.
(void)dbg_int;
|
||||
(void)dbg_fp32;
|
||||
(void)dbg_bf16;
|
||||
// The asm operand list below names exactly v_os_a0..v_os_a7, so the coordinate
// counts are checked at compile time.
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
|
||||
static_assert(BCoords::size() == Repeat_N);
|
||||
|
||||
// LDS window over smem using the store descriptor; only used below to query
// the async-store parameters (m0 init value and per-issue size).
auto a_sst = make_tile_window(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
|
||||
MakeLdsStoreDesc_A().get_lengths(),
|
||||
{0, 0, 0});
|
||||
|
||||
// Linear LDS window over the same smem using the load descriptor; its cached
// coordinate and per-access linear offsets feed the ds_read operands of the asm.
auto a_sld = [&]() {
|
||||
constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
|
||||
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
|
||||
sequence<WarpPerBlock_N>,
|
||||
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode =
|
||||
detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
|
||||
return make_tile_window_linear(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
|
||||
MakeLdsLoadDesc_A().get_lengths(),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(a_block_dstr_encode));
|
||||
}();
|
||||
|
||||
// Per-iteration tile advances, scaled by the element size (bytes).
const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
|
||||
const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
|
||||
|
||||
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
|
||||
// Size in bytes of the A load layout; passed to the asm as the immediate smem_sz.
constexpr auto smem_buf_size =
|
||||
MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
|
||||
// The asm names exactly sld_os_0..sld_os_7, hence the fixed access count.
static_assert(a_sld.get_num_of_access() == 8);
|
||||
// Compile-time byte offsets of each LDS load access, used as asm immediates.
constexpr auto sld_os = generate_tuple(
|
||||
[&](auto i_access) {
|
||||
return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
|
||||
},
|
||||
number<a_sld.get_num_of_access()>{});
|
||||
|
||||
// Integer division: any remainder of k modulo Block_K is not iterated.
index_t loop_cnt = k / Block_K;
|
||||
|
||||
// this is the acc thread buffer
|
||||
fp32x4_t v_acc[16]{.0f};
|
||||
// Only consumed through the [v_dbg] input operand of the asm statement below.
float dbg_dword = 0.0f;
|
||||
|
||||
// B nr->kr
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Winline-asm"
|
||||
// clang-format off
|
||||
// The asm template text comes from the included .inc file; CK_TILE_FLATMM_UK_MFMA is
// set to the FP16 value only for the duration of the include.
asm volatile(
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
|
||||
#include "uk/flatmm_ff_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
|
||||
#undef CK_TILE_FLATMM_UK_MFMA
|
||||
// Read-write operands: the loop counter, the 16 accumulators and the smem pointer.
: [s_loop_cnt]"+s"(loop_cnt),
|
||||
[v_acc_0]"+v"(v_acc[0]),
|
||||
[v_acc_1]"+v"(v_acc[1]),
|
||||
[v_acc_2]"+v"(v_acc[2]),
|
||||
[v_acc_3]"+v"(v_acc[3]),
|
||||
[v_acc_4]"+v"(v_acc[4]),
|
||||
[v_acc_5]"+v"(v_acc[5]),
|
||||
[v_acc_6]"+v"(v_acc[6]),
|
||||
[v_acc_7]"+v"(v_acc[7]),
|
||||
[v_acc_8]"+v"(v_acc[8]),
|
||||
[v_acc_9]"+v"(v_acc[9]),
|
||||
[v_acc_10]"+v"(v_acc[10]),
|
||||
[v_acc_11]"+v"(v_acc[11]),
|
||||
[v_acc_12]"+v"(v_acc[12]),
|
||||
[v_acc_13]"+v"(v_acc[13]),
|
||||
[v_acc_14]"+v"(v_acc[14]),
|
||||
[v_acc_15]"+v"(v_acc[15]),
|
||||
[s_mem_]"+r"(smem)
|
||||
// Read-only operands: resource dwords, byte offsets, and compile-time immediates ("n").
: [s_res_a0]"s"(res_a[0]),
|
||||
[s_res_a1]"s"(res_a[1]),
|
||||
[s_res_a2]"s"(res_a[2]),
|
||||
[s_res_a3]"s"(res_a[3]),
|
||||
[s_res_b0]"s"(res_b[0]),
|
||||
[s_res_b1]"s"(res_b[1]),
|
||||
[s_res_b2]"s"(res_b[2]),
|
||||
[s_res_b3]"s"(res_b[3]),
|
||||
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
|
||||
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
|
||||
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
|
||||
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
|
||||
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
|
||||
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
|
||||
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
|
||||
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
|
||||
|
||||
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
|
||||
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
|
||||
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
|
||||
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
|
||||
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
|
||||
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
|
||||
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
|
||||
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
|
||||
[v_dbg]"v"(dbg_dword),
|
||||
|
||||
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
|
||||
[s_m0_init]"s"(m0_init_value),
|
||||
[s_size_per_issue]"s"(size_per_issue),
|
||||
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
|
||||
[sld_os_0]"n"(sld_os[number<0>{}].value),
|
||||
[sld_os_1]"n"(sld_os[number<1>{}].value),
|
||||
[sld_os_2]"n"(sld_os[number<2>{}].value),
|
||||
[sld_os_3]"n"(sld_os[number<3>{}].value),
|
||||
[sld_os_4]"n"(sld_os[number<4>{}].value),
|
||||
[sld_os_5]"n"(sld_os[number<5>{}].value),
|
||||
[sld_os_6]"n"(sld_os[number<6>{}].value),
|
||||
[sld_os_7]"n"(sld_os[number<7>{}].value),
|
||||
[s_tile_os_a]"s"(tile_offset_a_bytes),
|
||||
[s_tile_os_b]"s"(tile_offset_b_bytes)
|
||||
// Clobbers: memory, a0-a255, s16-s23, s86 and v64-v127.
// NOTE(review): if the included micro-kernel also writes m0 or SCC, neither is
// listed here - confirm whether they need to be declared as clobbers.
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
|
||||
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
|
||||
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
|
||||
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
|
||||
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
|
||||
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
|
||||
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
|
||||
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
|
||||
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
|
||||
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
|
||||
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
|
||||
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
|
||||
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
|
||||
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
|
||||
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
|
||||
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
|
||||
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
|
||||
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
|
||||
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
|
||||
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
|
||||
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
|
||||
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
|
||||
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
|
||||
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
|
||||
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
|
||||
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
|
||||
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
|
||||
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
|
||||
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
|
||||
"a252", "a253", "a254", "a255",
|
||||
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
|
||||
"s86", // s86 as tmp
|
||||
"v64", "v65", "v66", "v67", "v68", "v69",
|
||||
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
|
||||
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
|
||||
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
|
||||
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
|
||||
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
|
||||
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
|
||||
"v124", "v125", "v126", "v127"
|
||||
);
|
||||
// clang-format on
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// Copy the 16 four-lane accumulators (64 floats) into the C tile's thread buffer.
// return local scratch
|
||||
auto c = MakeCBlockTile();
|
||||
for(auto i = 0; i < 16; i++)
|
||||
{
|
||||
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
|
||||
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
|
||||
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
|
||||
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,574 @@
|
||||
// Select the MFMA mnemonic that is spliced into the micro-kernel text below.
// The including translation unit may pre-define CK_TILE_FLATMM_UK_MFMA; when it
// does not, the BF16 instruction is used.
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif

#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
#else
// Fail here with a readable message instead of leaving _UK_MFMA_ undefined, which
// would otherwise surface as an obscure error inside the asm string.
#error "CK_TILE_FLATMM_UK_MFMA must be CK_TILE_FLATMM_UK_MFMA_BF16 or CK_TILE_FLATMM_UK_MFMA_FP16"
#endif
|
||||
|
||||
|
||||
|
||||
"s_mov_b32 s16, %[s_res_a0] \n"
|
||||
"s_mov_b32 s17, %[s_res_a1] \n"
|
||||
"s_mov_b32 s18, %[s_res_a2] \n"
|
||||
"s_mov_b32 s19, %[s_res_a3] \n"
|
||||
"s_mov_b32 s20, %[s_res_b0] \n"
|
||||
"s_mov_b32 s21, %[s_res_b1] \n"
|
||||
"s_mov_b32 s22, %[s_res_b2] \n"
|
||||
"s_mov_b32 s23, %[s_res_b3] \n"
|
||||
// "s_nop 4\n"
|
||||
"; -- prefetch A0\n"
|
||||
"s_add_u32 m0, 0, %[s_m0_init] \n"
|
||||
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n" // size_per_issue = 1088
|
||||
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
|
||||
|
||||
"s_add_u32 m0, %[smem_sz], %[s_m0_init] \n" // smem_sz = 8688
|
||||
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond \n"
|
||||
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n" // tile_offset_a_bytes = 256 = 128(bf16)K
|
||||
"s_add_u32 s16, s86, s16 ; move a with cond \n"
|
||||
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
|
||||
|
||||
"; -- prefetch A1\n"
|
||||
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
|
||||
|
||||
"s_add_u32 m0, 0, %[s_m0_init] \n"
|
||||
"s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
|
||||
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
|
||||
"s_add_u32 s16, s86, s16 ; move a with cond \n"
|
||||
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
|
||||
|
||||
"; -- prefetch B0\n"
|
||||
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n" // bf16 * 8
|
||||
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
|
||||
|
||||
#if 0
|
||||
// test a
|
||||
//" s_waitcnt vmcnt(0) & lgkmcnt(0) \n"
|
||||
//" s_barrier \n"
|
||||
//"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 // K stride
|
||||
//"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n"
|
||||
//"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n"
|
||||
//"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n"
|
||||
//"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n"
|
||||
//"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n"
|
||||
//"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n"
|
||||
//"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n"
|
||||
//" s_waitcnt vmcnt(0) & lgkmcnt(0) \n"
|
||||
//" s_barrier \n"
|
||||
//"v_add_u32 %[v_dbg], %[v_os_slda], %[sld_os_0] \n"
|
||||
//"v_mov_b32 %[v_dbg], v64 \n"
|
||||
|
||||
// test b
|
||||
" s_waitcnt vmcnt(0) & lgkmcnt(0) \n"
|
||||
" s_barrier \n"
|
||||
//"s_add_u32 s20, s86, s20 ; move b with cond \n"
|
||||
//"s_addc_u32 s21, 0, s21 ; move b with cond \n"
|
||||
"buffer_load_dwordx4 %[v_acc_0], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
|
||||
" s_waitcnt vmcnt(0) & lgkmcnt(0) \n"
|
||||
" s_barrier \n"
|
||||
"v_mov_b32 %[v_dbg], %[v_os_b2] \n"
|
||||
#else
|
||||
|
||||
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
|
||||
"s_cselect_b32 s86, %[s_tile_os_b], 0 ; move b with cond \n" // s_tile_os_b = 4096B = 2048(bf16)K
|
||||
"s_add_u32 s20, s86, s20 ; move b with cond \n"
|
||||
"s_addc_u32 s21, 0, s21 ; move b with cond \n"
|
||||
"s_waitcnt vmcnt(40) \n"
|
||||
"s_barrier \n"
|
||||
|
||||
"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 // K stride
|
||||
"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n"
|
||||
"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n"
|
||||
"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n"
|
||||
"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n"
|
||||
"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n"
|
||||
"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n"
|
||||
"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n"
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
"L_start%=: \n"
|
||||
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
|
||||
" s_barrier \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[0:1], v[64:65], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[2:3], v[66:67], %[v_acc_0] \n"
|
||||
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[4:5], v[68:69], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[6:7], v[70:71], %[v_acc_0] \n"
|
||||
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[8:9], v[72:73], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[10:11], v[74:75], %[v_acc_0] \n"
|
||||
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[12:13], v[76:77], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[14:15], v[78:79], %[v_acc_0] \n"
|
||||
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_1], acc[0:1], v[80:81], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[2:3], v[82:83], %[v_acc_1] \n"
|
||||
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[4:5], v[84:85], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[6:7], v[86:87], %[v_acc_1] \n"
|
||||
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[8:9], v[88:89], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[10:11], v[90:91], %[v_acc_1] \n"
|
||||
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[12:13], v[92:93], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[14:15], v[94:95], %[v_acc_1] \n"
|
||||
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_2], acc[16:17], v[64:65], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[18:19], v[66:67], %[v_acc_2] \n"
|
||||
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[20:21], v[68:69], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[22:23], v[70:71], %[v_acc_2] \n"
|
||||
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[24:25], v[72:73], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[26:27], v[74:75], %[v_acc_2] \n"
|
||||
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[28:29], v[76:77], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[30:31], v[78:79], %[v_acc_2] \n"
|
||||
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_3], acc[16:17], v[80:81], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[18:19], v[82:83], %[v_acc_3] \n"
|
||||
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[20:21], v[84:85], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[22:23], v[86:87], %[v_acc_3] \n"
|
||||
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[24:25], v[88:89], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[26:27], v[90:91], %[v_acc_3] \n"
|
||||
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[28:29], v[92:93], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[30:31], v[94:95], %[v_acc_3] \n"
|
||||
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_4], acc[32:33], v[64:65], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[34:35], v[66:67], %[v_acc_4] \n"
|
||||
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[36:37], v[68:69], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[38:39], v[70:71], %[v_acc_4] \n"
|
||||
" ds_read_b128 v[96:99], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[40:41], v[72:73], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[42:43], v[74:75], %[v_acc_4] \n"
|
||||
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[44:45], v[76:77], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[46:47], v[78:79], %[v_acc_4] \n"
|
||||
" ds_read_b128 v[100:103], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_1] \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_5], acc[32:33], v[80:81], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[34:35], v[82:83], %[v_acc_5] \n"
|
||||
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[36:37], v[84:85], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[38:39], v[86:87], %[v_acc_5] \n"
|
||||
" ds_read_b128 v[104:107], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[40:41], v[88:89], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[42:43], v[90:91], %[v_acc_5] \n"
|
||||
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[44:45], v[92:93], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[46:47], v[94:95], %[v_acc_5] \n"
|
||||
" ds_read_b128 v[108:111], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_3] \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_6], acc[48:49], v[64:65], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[50:51], v[66:67], %[v_acc_6] \n"
|
||||
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[52:53], v[68:69], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[54:55], v[70:71], %[v_acc_6] \n"
|
||||
" ds_read_b128 v[112:115], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[56:57], v[72:73], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[58:59], v[74:75], %[v_acc_6] \n"
|
||||
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[60:61], v[76:77], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[62:63], v[78:79], %[v_acc_6] \n"
|
||||
" ds_read_b128 v[116:119], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_5] \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_7], acc[48:49], v[80:81], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[50:51], v[82:83], %[v_acc_7] \n"
|
||||
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[52:53], v[84:85], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[54:55], v[86:87], %[v_acc_7] \n"
|
||||
" ds_read_b128 v[120:123], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[56:57], v[88:89], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[58:59], v[90:91], %[v_acc_7] \n"
|
||||
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[60:61], v[92:93], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[62:63], v[94:95], %[v_acc_7] \n"
|
||||
" ds_read_b128 v[124:127], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_7] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_8], acc[64:65], v[64:65], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[66:67], v[66:67], %[v_acc_8] \n"
|
||||
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[68:69], v[68:69], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[70:71], v[70:71], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[72:73], v[72:73], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[74:75], v[74:75], %[v_acc_8] \n"
|
||||
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[76:77], v[76:77], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[78:79], v[78:79], %[v_acc_8] \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_9], acc[64:65], v[80:81], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[66:67], v[82:83], %[v_acc_9] \n"
|
||||
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[68:69], v[84:85], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[70:71], v[86:87], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[72:73], v[88:89], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[74:75], v[90:91], %[v_acc_9] \n"
|
||||
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[76:77], v[92:93], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[78:79], v[94:95], %[v_acc_9] \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_10], acc[80:81], v[64:65], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[82:83], v[66:67], %[v_acc_10] \n"
|
||||
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[84:85], v[68:69], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[86:87], v[70:71], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[88:89], v[72:73], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[90:91], v[74:75], %[v_acc_10] \n"
|
||||
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[92:93], v[76:77], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[94:95], v[78:79], %[v_acc_10] \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_11], acc[80:81], v[80:81], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[82:83], v[82:83], %[v_acc_11] \n"
|
||||
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[84:85], v[84:85], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[86:87], v[86:87], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[88:89], v[88:89], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[90:91], v[90:91], %[v_acc_11] \n"
|
||||
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[92:93], v[92:93], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[94:95], v[94:95], %[v_acc_11] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_12], acc[96:97], v[64:65], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[98:99], v[66:67], %[v_acc_12] \n"
|
||||
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[100:101], v[68:69], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[102:103], v[70:71], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[104:105], v[72:73], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[106:107], v[74:75], %[v_acc_12] \n"
|
||||
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[108:109], v[76:77], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[110:111], v[78:79], %[v_acc_12] \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_13], acc[96:97], v[80:81], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[98:99], v[82:83], %[v_acc_13] \n"
|
||||
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[100:101], v[84:85], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[102:103], v[86:87], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[104:105], v[88:89], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[106:107], v[90:91], %[v_acc_13] \n"
|
||||
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[108:109], v[92:93], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[110:111], v[94:95], %[v_acc_13] \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_14], acc[112:113], v[64:65], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[114:115], v[66:67], %[v_acc_14] \n"
|
||||
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[116:117], v[68:69], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[118:119], v[70:71], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[120:121], v[72:73], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[122:123], v[74:75], %[v_acc_14] \n"
|
||||
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[124:125], v[76:77], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[126:127], v[78:79], %[v_acc_14] \n"
|
||||
|
||||
_UK_MFMA_ " %[v_acc_15], acc[112:113], v[80:81], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[114:115], v[82:83], %[v_acc_15] \n"
|
||||
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[116:117], v[84:85], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15] \n"
|
||||
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072\n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15] \n"
|
||||
////////////////////////////////////////////////////
|
||||
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
|
||||
" s_cbranch_scc0 L_end%= \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
|
||||
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
|
||||
" s_add_u32 s16, s86, s16 \n"
|
||||
" s_addc_u32 s17, 0, s17 \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
|
||||
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
|
||||
" s_add_u32 s20, s86, s20 \n"
|
||||
" s_addc_u32 s21, 0, s21 \n"
|
||||
" ;------------------------------------------ \n"
|
||||
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
|
||||
" s_barrier \n"
|
||||
///////////////////////////////////////////////////////
|
||||
_UK_MFMA_ " %[v_acc_0], acc[128:129], v[96:97], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[130:131], v[98:99], %[v_acc_0] \n"
|
||||
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[132:133], v[100:101], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[134:135], v[102:103], %[v_acc_0] \n"
|
||||
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[136:137], v[104:105], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[138:139], v[106:107], %[v_acc_0] \n"
|
||||
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[140:141], v[108:109], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[142:143], v[110:111], %[v_acc_0] \n"
|
||||
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[128:129], v[112:113], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[130:131], v[114:115], %[v_acc_1] \n"
|
||||
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[132:133], v[116:117], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[134:135], v[118:119], %[v_acc_1] \n"
|
||||
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[136:137], v[120:121], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[138:139], v[122:123], %[v_acc_1] \n"
|
||||
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[140:141], v[124:125], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[142:143], v[126:127], %[v_acc_1] \n"
|
||||
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[144:145], v[96:97], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[146:147], v[98:99], %[v_acc_2] \n"
|
||||
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[148:149], v[100:101], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[150:151], v[102:103], %[v_acc_2] \n"
|
||||
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[152:153], v[104:105], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[154:155], v[106:107], %[v_acc_2] \n"
|
||||
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[156:157], v[108:109], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[158:159], v[110:111], %[v_acc_2] \n"
|
||||
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[144:145], v[112:113], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[146:147], v[114:115], %[v_acc_3] \n"
|
||||
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[148:149], v[116:117], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[150:151], v[118:119], %[v_acc_3] \n"
|
||||
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[152:153], v[120:121], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[154:155], v[122:123], %[v_acc_3] \n"
|
||||
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[156:157], v[124:125], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[158:159], v[126:127], %[v_acc_3] \n"
|
||||
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, 0, %[s_m0_init] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[160:161], v[96:97], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[162:163], v[98:99], %[v_acc_4] \n"
|
||||
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[164:165], v[100:101], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[166:167], v[102:103], %[v_acc_4] \n"
|
||||
" ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[168:169], v[104:105], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[170:171], v[106:107], %[v_acc_4] \n"
|
||||
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[172:173], v[108:109], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[174:175], v[110:111], %[v_acc_4] \n"
|
||||
" ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[160:161], v[112:113], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[162:163], v[114:115], %[v_acc_5] \n"
|
||||
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[164:165], v[116:117], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[166:167], v[118:119], %[v_acc_5] \n"
|
||||
" ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] "
|
||||
"\n" _UK_MFMA_ " %[v_acc_5], acc[168:169], v[120:121], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[170:171], v[122:123], %[v_acc_5] \n"
|
||||
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[172:173], v[124:125], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[174:175], v[126:127], %[v_acc_5] \n"
|
||||
" ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] "
|
||||
"\n" _UK_MFMA_ " %[v_acc_6], acc[176:177], v[96:97], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[178:179], v[98:99], %[v_acc_6] \n"
|
||||
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[180:181], v[100:101], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[182:183], v[102:103], %[v_acc_6] \n"
|
||||
" ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] "
|
||||
"\n" _UK_MFMA_ " %[v_acc_6], acc[184:185], v[104:105], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[186:187], v[106:107], %[v_acc_6] \n"
|
||||
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[188:189], v[108:109], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[190:191], v[110:111], %[v_acc_6] \n"
|
||||
" ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] "
|
||||
"\n" _UK_MFMA_ " %[v_acc_7], acc[176:177], v[112:113], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[178:179], v[114:115], %[v_acc_7] \n"
|
||||
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[180:181], v[116:117], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[182:183], v[118:119], %[v_acc_7] \n"
|
||||
" ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] "
|
||||
"\n" _UK_MFMA_ " %[v_acc_7], acc[184:185], v[120:121], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[186:187], v[122:123], %[v_acc_7] \n"
|
||||
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[188:189], v[124:125], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[190:191], v[126:127], %[v_acc_7] \n"
|
||||
" ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[192:193], v[96:97], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[194:195], v[98:99], %[v_acc_8] \n"
|
||||
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[196:197], v[100:101], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[198:199], v[102:103], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[200:201], v[104:105], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[202:203], v[106:107], %[v_acc_8] \n"
|
||||
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[204:205], v[108:109], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[206:207], v[110:111], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[192:193], v[112:113], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[194:195], v[114:115], %[v_acc_9] \n"
|
||||
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[196:197], v[116:117], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[198:199], v[118:119], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[200:201], v[120:121], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[202:203], v[122:123], %[v_acc_9] \n"
|
||||
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[204:205], v[124:125], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[206:207], v[126:127], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[208:209], v[96:97], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[210:211], v[98:99], %[v_acc_10] \n"
|
||||
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[212:213], v[100:101], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[214:215], v[102:103], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[216:217], v[104:105], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[218:219], v[106:107], %[v_acc_10] \n"
|
||||
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[220:221], v[108:109], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[222:223], v[110:111], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[208:209], v[112:113], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[210:211], v[114:115], %[v_acc_11] \n"
|
||||
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[212:213], v[116:117], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[214:215], v[118:119], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[216:217], v[120:121], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[218:219], v[122:123], %[v_acc_11] \n"
|
||||
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[220:221], v[124:125], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[222:223], v[126:127], %[v_acc_11] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[224:225], v[96:97], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[226:227], v[98:99], %[v_acc_12] \n"
|
||||
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[228:229], v[100:101], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[230:231], v[102:103], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[232:233], v[104:105], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[234:235], v[106:107], %[v_acc_12] \n"
|
||||
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[236:237], v[108:109], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[238:239], v[110:111], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[224:225], v[112:113], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[226:227], v[114:115], %[v_acc_13] \n"
|
||||
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[228:229], v[116:117], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[230:231], v[118:119], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[232:233], v[120:121], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[234:235], v[122:123], %[v_acc_13] \n"
|
||||
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[236:237], v[124:125], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[238:239], v[126:127], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[240:241], v[96:97], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[242:243], v[98:99], %[v_acc_14] \n"
|
||||
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[244:245], v[100:101], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[246:247], v[102:103], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[248:249], v[104:105], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[250:251], v[106:107], %[v_acc_14] \n"
|
||||
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[252:253], v[108:109], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[254:255], v[110:111], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[240:241], v[112:113], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[242:243], v[114:115], %[v_acc_15] \n"
|
||||
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[244:245], v[116:117], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[246:247], v[118:119], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[248:249], v[120:121], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[250:251], v[122:123], %[v_acc_15] \n"
|
||||
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[252:253], v[124:125], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[254:255], v[126:127], %[v_acc_15] \n"
|
||||
|
||||
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
|
||||
" s_cbranch_scc0 L_end%= \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
|
||||
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
|
||||
" s_add_u32 s16, s86, s16 \n"
|
||||
" s_addc_u32 s17, 0, s17 \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
|
||||
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
|
||||
" s_add_u32 s20, s86, s20 \n"
|
||||
" s_addc_u32 s21, 0, s21 \n"
|
||||
" s_branch L_start%= \n"
|
||||
"L_end%=: \n"
|
||||
" s_nop 2 \n"
|
||||
#endif
|
||||
|
||||
#undef _UK_MFMA_
|
||||
|
||||
17
include/ck_tile/ops/flatmm_uk.hpp
Normal file
17
include/ck_tile/ops/flatmm_uk.hpp
Normal file
@@ -0,0 +1,17 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_ff_32x512x128_1x4x1_16x16x32.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/flatmm_uk_kernel.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
264
include/ck_tile/ops/fused_moe/kernel/flatmm_uk_kernel.hpp
Normal file
264
include/ck_tile/ops/fused_moe/kernel/flatmm_uk_kernel.hpp
Normal file
@@ -0,0 +1,264 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
// clang-format off
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
// each slice map to one expert, and one expert can have multiple slices
|
||||
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
|
||||
//
|
||||
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
|
||||
//
|
||||
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
|
||||
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
|
||||
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
|
||||
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
|
||||
//
|
||||
// 32bit 0........23 24.....31 bit
|
||||
// (data) -> (token_id | topk_id)
|
||||
// low 24 bit is for token id, top 8 bit is for topk id
|
||||
//
|
||||
// the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim]
|
||||
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
|
||||
// * length is (max_num_tokens_padded + block_size - 1) / block_size
|
||||
//
|
||||
// num_tokens_post_padded_ptr : [28]
|
||||
// num_sorted_tiles_ptr : [7]
|
||||
//
|
||||
// * different from vLLM
|
||||
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
|
||||
// 2)need sorted_weight_ptr
|
||||
// 3) use num_sorted_tiles_ptr, already divided by M_a
|
||||
//
|
||||
// * below used for indexing
|
||||
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
|
||||
// 2) sorted_weight_ptr
|
||||
// 3) sorted_expert_ids_ptr
|
||||
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
|
||||
//
|
||||
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
|
||||
//
|
||||
// [indexing implementation-2]
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
|
||||
//
|
||||
// we generate original row/col id as
|
||||
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
|
||||
// let x be one element of above, we can get:
|
||||
// topk_row_id(token_id) = x % num_tokens(5)
|
||||
// topk_col_id(expert_Id) = x / num_tokens
|
||||
// topk_row_id/col_id can be used to access original topk_ids/topk_weight
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// we can get permuted_rc_ids:
|
||||
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
|
||||
//
|
||||
//
|
||||
// clang-format on
|
||||
//
|
||||
namespace ck_tile {
|
||||
|
||||
// m: num_tokens (or token*input-batch)
|
||||
// k: hidden_size
|
||||
// n: intermediate_size used between 2 FC (TP slice this)
|
||||
// e: num expert
|
||||
// if doing pre-shuffle
|
||||
// nr : n / Block_Nr
|
||||
// kr : k / Block_Kr
|
||||
// w : flattened 1d wave buffer
|
||||
// Host-side argument bundle for the flatmm micro-kernel GEMM launch.
// The field order/types must stay identical to the kernel-side argument struct,
// because the kernel converts one into the other with a bit_cast.
struct FlatmmUkHostArgs
|
||||
{
|
||||
const void* a_ptr; // read-only input; NOTE(review): presumably the A operand, [m, k] - confirm
|
||||
const void* b_ptr; // read-only input; NOTE(review): presumably the B (weight) operand - confirm shape/layout
|
||||
const void* c_ptr; // read-only input; NOTE(review): role not evident from this header - confirm
|
||||
void* d_ptr; // writable buffer; NOTE(review): presumably the GEMM output - confirm shape
|
||||
void* dbg_int_ptr; // debug scratch buffer; NOTE(review): presumably int elements - confirm it is still needed
|
||||
void* dbg_bf16_ptr; // debug scratch buffer; NOTE(review): presumably bf16 elements - confirm it is still needed
|
||||
void* dbg_fp32_ptr; // debug scratch buffer; NOTE(review): presumably fp32 elements - confirm it is still needed
|
||||
|
||||
index_t hidden_size; // K
|
||||
index_t intermediate_size; // N
|
||||
index_t num_tokens; // M
|
||||
|
||||
index_t num_experts; // number of groups
|
||||
index_t topk; // need this?
|
||||
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
};
|
||||
|
||||
// This is scatter/gather b2b group-gemm
|
||||
template <typename Pipeline_, typename Epilogue_>
|
||||
struct FlatmmUkKernel
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used
|
||||
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
|
||||
// static_assert(kBlockPerCu > 0);
|
||||
|
||||
using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
|
||||
static constexpr index_t BlockSize_ = BlockShape::BlockSize;
|
||||
|
||||
using ADataType = typename Pipeline::Problem::ADataType;
|
||||
using GDataType = typename Pipeline::Problem::GDataType;
|
||||
using DDataType = typename Pipeline::Problem::AccDataType;
|
||||
using AccDataType = typename Pipeline::Problem::AccDataType;
|
||||
using ODataType = typename Pipeline::Problem::ODataType;
|
||||
using AScaleDataType = typename Pipeline::Problem::AScaleDataType;
|
||||
using GScaleDataType = typename Pipeline::Problem::GScaleDataType;
|
||||
using DScaleDataType = typename Pipeline::Problem::DScaleDataType;
|
||||
using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
|
||||
using TopkWeightDataType = typename Pipeline::Problem::TopkWeightDataType;
|
||||
using IndexDataType = typename Pipeline::Problem::IndexDataType;
|
||||
using YDataType = typename Pipeline::Problem::YDataType;
|
||||
|
||||
using Traits = typename Pipeline::Problem::Traits;
|
||||
static constexpr bool UseUK = true;
|
||||
|
||||
static constexpr bool IsGateOnly = Traits::IsGateOnly;
|
||||
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
|
||||
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
|
||||
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; };
|
||||
// clang-format on
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
// clang-format off
|
||||
using S_ = BlockShape;
|
||||
|
||||
auto prec_str = [&] () {
|
||||
std::string base_str = _SS_(t2s<ADataType>::name);
|
||||
if (!std::is_same_v<ADataType, GDataType>) {
|
||||
base_str += _SS_("_") + _SS_(t2s<GDataType>::name);
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
|
||||
_TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
|
||||
_TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
|
||||
_TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
struct FusedMoeGemmKargs
|
||||
{
|
||||
const void* a_ptr; // [m, k], input token
|
||||
const void* b_ptr; // [m, k], input token
|
||||
const void* c_ptr; // [m, k], output token
|
||||
void* d_ptr; // [m, k], output token
|
||||
void* dbg_int_ptr; // [m, k], output token
|
||||
void* dbg_bf16_ptr; // [m, k], output token
|
||||
void* dbg_fp32_ptr; // [m, k], output token
|
||||
|
||||
index_t hidden_size; // K
|
||||
index_t intermediate_size; // N
|
||||
index_t num_tokens; // M
|
||||
|
||||
index_t num_experts; // number of groups
|
||||
index_t topk; // need this?
|
||||
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
};
|
||||
|
||||
// TODO: switch karg based on
|
||||
using Kargs = FusedMoeGemmKargs;
|
||||
using Hargs = FlatmmUkHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
|
||||
{
|
||||
// TODO: hargs/kargs not guranteed to be the same
|
||||
return bit_cast<Kargs>(hargs);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
{
|
||||
index_t ms = ck_tile::integer_divide_ceil(hargs.num_tokens, BlockShape::Block_M0);
|
||||
index_t ns = ck_tile::integer_divide_ceil(hargs.intermediate_size, BlockShape::Block_N0);
|
||||
return dim3(ns, ms, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
#if 0
|
||||
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
|
||||
{
|
||||
printf("[KERNEL] FlatmmUkKernel =====\n");
|
||||
printf("[KERNEL] blockDim: [%d, %d], gridDim: [%d, %d]\n",
|
||||
static_cast<int>(blockDim.x),
|
||||
static_cast<int>(blockDim.y),
|
||||
static_cast<int>(gridDim.x),
|
||||
static_cast<int>(gridDim.y));
|
||||
printf("[KERNEL] lds = %.3f (KB)\n", GetSmemSize() / 1024.0f);
|
||||
}
|
||||
|
||||
[[maybe_unused]] uint32_t tidx = threadIdx.x; // 0~255
|
||||
[[maybe_unused]] uint32_t tidy = threadIdx.y; // 0~0
|
||||
[[maybe_unused]] uint32_t bidx = blockIdx.x; // 0~1
|
||||
[[maybe_unused]] uint32_t bidy = blockIdx.y; // 0~51
|
||||
[[maybe_unused]] uint32_t bdmx = blockDim.x; // 256
|
||||
[[maybe_unused]] uint32_t bdmy = blockDim.y; // 1
|
||||
[[maybe_unused]] uint32_t gdmx = gridDim.x; // 2
|
||||
[[maybe_unused]] uint32_t gdmy = gridDim.y; // 52
|
||||
[[maybe_unused]] uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy
|
||||
+ (bdmx * bdmy) * bidx
|
||||
+ bdmx * tidy
|
||||
+ tidx;
|
||||
|
||||
[[maybe_unused]]int * dbg_int = static_cast<int*>(kargs.dbg_int_ptr);
|
||||
[[maybe_unused]]short * dbg_bf16 = static_cast<short*>(kargs.dbg_bf16_ptr);
|
||||
[[maybe_unused]]float * dbg_fp32 = static_cast<float*>(kargs.dbg_fp32_ptr);
|
||||
|
||||
dbg_int[gid] = -1;
|
||||
dbg_fp32[gid] = -1.0f;
|
||||
#endif
|
||||
|
||||
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
|
||||
|
||||
Pipeline{}(kargs, smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
337
include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline.hpp
Normal file
337
include/ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline.hpp
Normal file
@@ -0,0 +1,337 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/flatmm_uk_pipeline_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
This pipeline deals with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
|
||||
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
|
||||
|
||||
<----- gemm-N ------>
|
||||
+----+----+----+----+
|
||||
| w0 | w1 | w2 | w3 | gemm-m
|
||||
+----+----+----+----+
|
||||
*/
|
||||
template <typename Problem_, typename Policy_ = GemmPipelineFlatmmPolicy>
|
||||
struct GemmPipeline_FlatmmUk
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
|
||||
|
||||
using ADataType = typename Problem::ADataType;
|
||||
using GDataType = typename Problem::GDataType;
|
||||
using DDataType = typename Problem::AccDataType;
|
||||
using AccDataType = typename Problem::AccDataType;
|
||||
using ODataType = typename Problem::ODataType;
|
||||
using AScaleDataType = typename Problem::AScaleDataType;
|
||||
using GScaleDataType = typename Problem::GScaleDataType;
|
||||
using DScaleDataType = typename Problem::DScaleDataType;
|
||||
using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
|
||||
using TopkWeightDataType = typename Problem::TopkWeightDataType;
|
||||
using IndexDataType = typename Problem::IndexDataType;
|
||||
using YDataType = typename Problem::YDataType;
|
||||
|
||||
using Traits = typename Problem::Traits;
|
||||
|
||||
static constexpr bool IsGateOnly = Traits::IsGateOnly;
|
||||
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
|
||||
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
|
||||
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
|
||||
|
||||
static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
|
||||
static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
|
||||
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
|
||||
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
|
||||
|
||||
static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
|
||||
static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
|
||||
static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
|
||||
static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
|
||||
|
||||
// Blocks per CU: taken from Problem::kBlockPerCu, except that the sentinel
// value -1 selects the built-in default of 2.
static constexpr index_t kBlockPerCu = []() {
    if constexpr(Problem::kBlockPerCu == -1)
    {
        // no explicit request from the problem description
        return 2;
    }
    else
    {
        return Problem::kBlockPerCu;
    }
}();
|
||||
|
||||
static constexpr const char* name = "flatmm_uk";
|
||||
|
||||
// Shared-memory requirement of the pipeline: the larger of
//  - what the micro-kernel chosen by the policy reports, and
//  - one Block_M0 x Block_N0 tile of YDataType (the "bridge" buffer).
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    constexpr index_t uk_smem = Policy::template GetUK_0<Problem>().GetSmemSize();
    constexpr index_t bridge_smem =
        BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);

    return max(uk_smem, bridge_smem);
}
|
||||
|
||||
// this is the thread-offset along row/col
|
||||
CK_TILE_HOST_DEVICE static auto GetACoord()
|
||||
{
|
||||
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
|
||||
const auto a_coord = a_dist.calculate_index();
|
||||
return a_coord;
|
||||
}
|
||||
|
||||
// this is the thread-offset along row/col
|
||||
CK_TILE_HOST_DEVICE static auto GetOCoord()
|
||||
{
|
||||
constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
|
||||
const auto o_coord = o_dist.calculate_index();
|
||||
return o_coord;
|
||||
}
|
||||
|
||||
// Number of A rows each thread covers: Block_M0 divided by the number of lanes left along M
// once Block_K0 / kAlignmentA lanes are assigned along K.
CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
{
    constexpr index_t lanes_along_k   = BlockShape::Block_K0 / kAlignmentA;
    constexpr index_t lanes_along_m   = BlockShape::BlockSize / lanes_along_k;
    constexpr index_t rows_per_thread = BlockShape::Block_M0 / lanes_along_m;
    return rows_per_thread;
}
|
||||
|
||||
// TODO: properlly support scatter/gather
|
||||
CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
|
||||
{
|
||||
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
|
||||
constexpr index_t MLans = BlockShape::BlockSize / KLans;
|
||||
constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
|
||||
|
||||
auto base_coord = threadIdx.x / KLans + base_offset;
|
||||
|
||||
array<index_t, MRepeat> coords;
|
||||
static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
|
||||
|
||||
return coords;
|
||||
}
|
||||
// Row coordinates for the output tile, computed like GetRowCoords_A but with the lane split
// derived from Block_N0 / kAlignmentO instead of Block_K0 / kAlignmentA.
CK_TILE_DEVICE auto GetRowCoords_O2(index_t base_offset)
{
    constexpr index_t lanes_n  = BlockShape::Block_N0 / kAlignmentO;
    constexpr index_t lanes_m  = BlockShape::BlockSize / lanes_n;
    constexpr index_t num_rows = BlockShape::Block_M0 / lanes_m;

    auto first_row = threadIdx.x / lanes_n + base_offset;

    array<index_t, num_rows> rows;
    static_for<0, num_rows, 1>{}([&](auto r) { rows.at(r) = first_row + r * lanes_m; });
    return rows;
}
|
||||
|
||||
template <typename ROW_COORDS>
|
||||
CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr)
|
||||
{
|
||||
constexpr index_t n_size = coords.size();
|
||||
|
||||
array<index_t, n_size> row_ids;
|
||||
static_for<0, n_size, 1>{}([&](auto i) {
|
||||
row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
|
||||
});
|
||||
|
||||
return row_ids;
|
||||
}
|
||||
|
||||
template <typename ROW_COORDS>
|
||||
CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
|
||||
const TopkWeightDataType* sorted_weight_ptr)
|
||||
{
|
||||
constexpr index_t n_size = coords.size();
|
||||
|
||||
array<TopkWeightDataType, n_size> w;
|
||||
static_for<0, n_size, 1>{}([&](auto i) {
|
||||
w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
|
||||
});
|
||||
|
||||
return w;
|
||||
}
|
||||
|
||||
// TODO: this row id is before shuffle atomic, need use acc distribution
|
||||
CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
|
||||
{
|
||||
constexpr index_t MLanes = BlockShape::Warp_M1;
|
||||
constexpr index_t Repeat_M = BlockShape::Repeat_M1;
|
||||
|
||||
auto base_coord = threadIdx.x % MLanes + base_offset;
|
||||
|
||||
array<index_t, Repeat_M> coords;
|
||||
static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; });
|
||||
|
||||
return coords;
|
||||
}
|
||||
|
||||
template <typename Karg>
|
||||
CK_TILE_DEVICE auto operator()(const Karg& kargs, CK_TILE_LDS_ADDR void* smem)
|
||||
{
|
||||
#if 0
|
||||
if(threadIdx.x == 0 && blockIdx.x == 0 && threadIdx.y == 0 && blockIdx.y == 0)
|
||||
{
|
||||
printf("[PIPE] GemmPipeline_FlatmmUk =====\n");
|
||||
}
|
||||
|
||||
[[maybe_unused]] uint32_t tidx = threadIdx.x; // 0~255
|
||||
[[maybe_unused]] uint32_t tidy = threadIdx.y; // 0~0
|
||||
[[maybe_unused]] uint32_t bidx = blockIdx.x; // 0~1
|
||||
[[maybe_unused]] uint32_t bidy = blockIdx.y; // 0~51
|
||||
[[maybe_unused]] uint32_t bdmx = blockDim.x; // 256
|
||||
[[maybe_unused]] uint32_t bdmy = blockDim.y; // 1
|
||||
[[maybe_unused]] uint32_t gdmx = gridDim.x; // 2
|
||||
[[maybe_unused]] uint32_t gdmy = gridDim.y; // 52
|
||||
[[maybe_unused]] uint32_t gid = ((bdmx * bdmy) * gdmx) * bidy
|
||||
+ (bdmx * bdmy) * bidx
|
||||
+ bdmx * tidy
|
||||
+ tidx;
|
||||
#endif
|
||||
[[maybe_unused]] int* dbg_int = static_cast<int*>(kargs.dbg_int_ptr);
|
||||
[[maybe_unused]] short* dbg_bf16 = static_cast<short*>(kargs.dbg_bf16_ptr);
|
||||
[[maybe_unused]] float* dbg_fp32 = static_cast<float*>(kargs.dbg_fp32_ptr);
|
||||
|
||||
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size; // N
|
||||
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
|
||||
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
|
||||
index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane(
|
||||
blockIdx.x * BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// a
|
||||
auto a_res =
|
||||
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
|
||||
kargs.num_tokens * kargs.hidden_size * sizeof(ADataType));
|
||||
auto row_ids_a = GetRowCoords_A(blockIdx.y * BlockShape::Block_M0);
|
||||
auto a_coords = generate_tuple(
|
||||
[&](auto i) {
|
||||
return row_ids_a[i] * kargs.hidden_size +
|
||||
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
|
||||
},
|
||||
number<row_ids_a.size()>{});
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// b
|
||||
auto b_win = [&]() {
|
||||
const GDataType* b_ptr = reinterpret_cast<const GDataType*>(kargs.b_ptr) +
|
||||
interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
|
||||
auto b_view_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
|
||||
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
|
||||
number<kAlignmentG>{},
|
||||
number<1>{});
|
||||
|
||||
auto b_window_ = make_tile_window_linear_raw(
|
||||
b_view_,
|
||||
make_tuple(number<BlockShape::Block_Nr0>{},
|
||||
number<BlockShape::Block_Kr0>{},
|
||||
number<BlockShape::Block_W0>{}),
|
||||
{0, 0, 0},
|
||||
Policy::template MakeGlobalTileDistribution_G<Problem>(),
|
||||
sequence<0, 1, 1>{});
|
||||
return b_window_;
|
||||
}();
|
||||
auto b_res = b_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
|
||||
auto b_coords = generate_tuple([&](auto i) { return b_win.cached_coords_[i].get_offset(); },
|
||||
number<decltype(b_win)::NumAccess_NonLinear>{});
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// core
|
||||
auto uk_0 = Policy::template GetUK_0<Problem>();
|
||||
auto acc_0 = uk_0(a_res,
|
||||
a_coords,
|
||||
b_res,
|
||||
b_coords,
|
||||
smem,
|
||||
kargs.hidden_size,
|
||||
BlockShape::Block_K0, // tile offset for B matrix each unroll
|
||||
BlockShape::Block_Kr0 *
|
||||
BlockShape::Block_W0, // tile offset for B matrix each unroll
|
||||
dbg_int,
|
||||
dbg_bf16,
|
||||
dbg_fp32);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
{
|
||||
int tid = threadIdx.x;
|
||||
float srdfp32 = 0.f;
|
||||
float* smemfp32 = static_cast<float*>(smem);
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// store to lds
|
||||
for(uint32_t accIdx = 0; accIdx < 16; accIdx++)
|
||||
{
|
||||
float* accSmem = smemfp32 + 4 * blockDim.x * accIdx;
|
||||
for(int xyzw = 0; xyzw < 4; xyzw++)
|
||||
{
|
||||
accSmem[tid * 4 + xyzw] = acc_0.get_thread_buffer()[accIdx * 4 + xyzw];
|
||||
}
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// read from lds
|
||||
int sldIdx = 0;
|
||||
// int MLn = 15;
|
||||
// int Nln = tid / MLn;
|
||||
int tidInWave = tid % 64;
|
||||
int waveId = tid / 64;
|
||||
// sldIdx = (tid64 % 16 * 16 + tid64 / 16) % 64
|
||||
// + tid / 64;
|
||||
sldIdx = (tidInWave % 16 * 16 + tidInWave / 16) + waveId * 4;
|
||||
|
||||
const int accNLane = 16;
|
||||
const int NLaneCnt = BlockShape::Block_N0 / 4; // xyzw 512 / 4 = 128
|
||||
const int accBlkSize = blockDim.x;
|
||||
|
||||
int accInnerId = tid % accNLane; // 0~15
|
||||
int accNIdx = tid / NLaneCnt; // 0~127 = 0; 128~255 = 1
|
||||
int acc01BlkIdx = tid % NLaneCnt / 16; // 0 ~ 7
|
||||
int accBlkIdx = acc01BlkIdx * 2; // 0, 2, 4, ..., 14
|
||||
int acc4Id = accBlkIdx * accBlkSize //
|
||||
+ accNIdx * accBlkSize + accInnerId * 16;
|
||||
sldIdx = acc4Id;
|
||||
|
||||
float* d_buf = static_cast<float*>(kargs.d_ptr);
|
||||
int c_blk_offset = blockIdx.y * BlockShape::Block_M0 * kargs.intermediate_size / 4 +
|
||||
blockIdx.x * BlockShape::Block_N0 / 4;
|
||||
|
||||
for(uint32_t accIdx = 0; accIdx < 16; accIdx++)
|
||||
{
|
||||
for(int xyzw = 0; xyzw < 4; xyzw++)
|
||||
{
|
||||
srdfp32 = smemfp32[accIdx * (1 * 4) + sldIdx * 4 + xyzw];
|
||||
acc_0.get_thread_buffer()[accIdx * 4 + xyzw] = srdfp32;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// store to vmem
|
||||
int c_m_idx_offset = (accIdx + accNIdx * 16) * kargs.intermediate_size / 4;
|
||||
int c_idx_offset = c_blk_offset + c_m_idx_offset + (tid % NLaneCnt);
|
||||
|
||||
for(int xyzw = 0; xyzw < 4; xyzw++)
|
||||
{
|
||||
srdfp32 = acc_0.get_thread_buffer()[accIdx * 4 + xyzw];
|
||||
d_buf[c_idx_offset * 4 + xyzw] = srdfp32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
// ----------------------------------------------------------------------------
|
||||
// debug
|
||||
for(uint32_t dbgi = 0; dbgi < 64; dbgi++)
|
||||
{
|
||||
dbg_fp32[gid * 64 + dbgi] = acc_0.get_thread_buffer()[dbgi];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,809 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
|
||||
#include "ck_tile/ops/flatmm_uk.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct GemmPipelineFlatmmPolicy
|
||||
{
|
||||
// Number of dwords moved per async copy.
// TODO: always 1 dword
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
    constexpr index_t dwords_per_copy = 1;
    return dwords_per_copy;
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
|
||||
{
|
||||
// using async
|
||||
constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
|
||||
constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
|
||||
static_assert(copy_bytes % data_bytes == 0);
|
||||
return copy_bytes / data_bytes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
|
||||
{
|
||||
constexpr index_t copy_bytes = [&]() { return 16; }();
|
||||
constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
|
||||
static_assert(copy_bytes % data_bytes == 0);
|
||||
return copy_bytes / data_bytes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
|
||||
{
|
||||
constexpr index_t copy_bytes = [&]() { return 16; }();
|
||||
constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
|
||||
static_assert(copy_bytes % data_bytes == 0);
|
||||
return copy_bytes / data_bytes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O()
|
||||
{
|
||||
if constexpr(Problem::Traits::OAtomic == 1)
|
||||
{
|
||||
// pack fp16/bf16 atomic
|
||||
static_assert(sizeof(typename Problem::ODataType) == 2);
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(Problem::Traits::OAtomic == 2)
|
||||
{
|
||||
// fp32 atomic
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 16 / sizeof(typename Problem::ODataType);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType_>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
|
||||
{
|
||||
// TODO: this is for 3d layout
|
||||
return 16 / sizeof(remove_cvref_t<DataType_>);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
|
||||
{
|
||||
return GetSmemKPack<typename Problem::ADataType>();
|
||||
}
|
||||
|
||||
// used for bridge LDS shuffle
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y()
|
||||
{
|
||||
// TODO: this should match mfma layout
|
||||
return 16 / sizeof(typename Problem::YDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
|
||||
{
|
||||
constexpr auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
|
||||
constexpr auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
|
||||
static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
|
||||
return a_sld_desc.get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
|
||||
{
|
||||
constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>();
|
||||
constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>();
|
||||
static_assert(bridge_sld_desc.get_element_space_size() ==
|
||||
bridge_sst_desc.get_element_space_size());
|
||||
return bridge_sld_desc.get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t a_lds = GetSmemSize_A<Problem>();
|
||||
constexpr index_t bridge_lds = GetSmemSize_Bridge<Problem>();
|
||||
return max(a_lds, bridge_lds);
|
||||
}
|
||||
|
||||
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
|
||||
{
|
||||
constexpr index_t K_vec = Alignment;
|
||||
constexpr index_t K_rem = KPerBlock / K_vec;
|
||||
|
||||
if constexpr(get_warp_size() < K_rem)
|
||||
{
|
||||
static_assert(K_rem % get_warp_size() == 0);
|
||||
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
|
||||
constexpr index_t K_wav = K_rem / get_warp_size();
|
||||
static_assert(K_wav <= NumWarps, "not not support thread has repeat along K yet");
|
||||
constexpr index_t M_wav = NumWarps / K_wav;
|
||||
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
|
||||
constexpr index_t M_rep = MPerBlock / M_wav;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
|
||||
tuple<sequence<1, 2>, sequence<2>>,
|
||||
tuple<sequence<1, 0>, sequence<1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K_lan = K_rem;
|
||||
constexpr index_t M_lan = get_warp_size() / K_lan;
|
||||
constexpr index_t M_wav = NumWarps;
|
||||
static_assert(MPerBlock % (M_lan * M_wav) == 0,
|
||||
"this tile size is too small please check");
|
||||
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
}
|
||||
|
||||
// optimized version for async, not same as simple MXK dist(pay attention!!)
|
||||
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
|
||||
{
|
||||
constexpr index_t K_vec = Alignment;
|
||||
constexpr index_t K_rem = KPerBlock / K_vec;
|
||||
|
||||
if constexpr(get_warp_size() <= K_rem)
|
||||
{
|
||||
static_assert(K_rem % get_warp_size() == 0);
|
||||
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
|
||||
constexpr index_t K_wav = K_rem / get_warp_size();
|
||||
static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet");
|
||||
constexpr index_t M_wav = NumWarps / K_wav;
|
||||
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
|
||||
constexpr index_t M_rep = MPerBlock / M_wav;
|
||||
// NOTE: no swap, but hard to avoid LDS bank conflict
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
|
||||
tuple<sequence<1, 2>, sequence<2>>,
|
||||
tuple<sequence<1, 0>, sequence<1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K_lan = K_rem;
|
||||
constexpr index_t M_lan = get_warp_size() / K_lan;
|
||||
constexpr index_t M_wav = NumWarps;
|
||||
static_assert(MPerBlock % (M_lan * M_wav) == 0,
|
||||
"this tile size is too small please check");
|
||||
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
|
||||
// NOTE: swapped for LDS load bank conflict free
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
// Note M_wave(num waves) is the fastest dim, different from sipmle 2d
|
||||
// distribution
|
||||
tuple<sequence<M_rep, M_lan, M_wav>, sequence<K_lan, K_vec>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<2>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t WarpPerBlock_N_,
|
||||
index_t WarpPerBlock_K_,
|
||||
index_t Repeat_N_,
|
||||
index_t Repeat_K_,
|
||||
index_t WarpSize_,
|
||||
index_t Alignment_>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_Nr_Kr_W()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Repeat_N_, WarpPerBlock_N_>,
|
||||
sequence<Repeat_K_, WarpPerBlock_K_>,
|
||||
sequence<WarpSize_, Alignment_>>,
|
||||
tuple<sequence<1, 2>, sequence<3>>,
|
||||
tuple<sequence<1, 1>, sequence<0>>,
|
||||
sequence<1, 2, 3>,
|
||||
sequence<0, 0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
|
||||
{
|
||||
constexpr index_t Block_M_ = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_K_ = Problem::BlockShape::Block_K0;
|
||||
constexpr index_t NumWarps_ = Problem::BlockShape::NumWarps;
|
||||
constexpr index_t Alignment_ = GetAlignment_A<Problem>();
|
||||
return MakeGlobalTileDistribution_SimpleMxK_Async<Block_M_,
|
||||
Block_K_,
|
||||
NumWarps_,
|
||||
Alignment_>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
|
||||
{
|
||||
constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
|
||||
// constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
|
||||
using S_ = typename Problem::BlockShape;
|
||||
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
|
||||
{
|
||||
// number<S_::WarpPerBlock_N0>{}.rrr();
|
||||
// number<S_::Repeat_N0>{}.eee();
|
||||
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
|
||||
S_::WarpPerBlock_K0,
|
||||
S_::Repeat_N0, /// hidden_radio_0,
|
||||
S_::Repeat_K0,
|
||||
get_warp_size(),
|
||||
GetAlignment_G<Problem>()>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
|
||||
{
|
||||
constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
|
||||
using S_ = typename Problem::BlockShape;
|
||||
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
|
||||
{
|
||||
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N1,
|
||||
S_::WarpPerBlock_K1,
|
||||
S_::Repeat_N1,
|
||||
S_::Repeat_K1,
|
||||
get_warp_size(),
|
||||
GetAlignment_D<Problem>()>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
|
||||
{
|
||||
using S_ = remove_cvref_t<typename Problem::BlockShape>;
|
||||
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
|
||||
// using CDataType = typename WarpGemm::CDataType;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
|
||||
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
return c_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
|
||||
{
|
||||
// A async->LDS
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
|
||||
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
|
||||
constexpr index_t KPad = KPack; // pad between warps
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
if constexpr(wavesPerK > NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t wavesPerM = NumWarps / wavesPerK;
|
||||
constexpr index_t NumIssues = Block_M / wavesPerM;
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NumIssues>{}),
|
||||
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
|
||||
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
return lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<LaneGroups>{}, // m1
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KVector>{}, // lds store vector(actually no explicit store)
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<NumIssues>{}),
|
||||
make_pass_through_transform(number<NumWarps>{}),
|
||||
make_merge_transform(make_tuple(
|
||||
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
|
||||
|
||||
return lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
|
||||
{
|
||||
// A async->LDS
|
||||
// Note that, this descriptor is only to construct the layout inside LDS
|
||||
// in real Gemm pipeline, ds_read may not follow this pattern
|
||||
// (may follow that in tile_distribution)
|
||||
// below code is almost the same as SmemStore dist, with difference:
|
||||
// 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
|
||||
// 2). return discriptor is in NxK 2d layout
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
|
||||
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
|
||||
constexpr index_t KPad = KPack; // pad between warps
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
if constexpr(wavesPerK >= NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t wavesPerM = NumWarps / wavesPerK;
|
||||
constexpr index_t NumIssues = Block_M / wavesPerM;
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KPack>{}, // lds load vector
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
|
||||
make_merge_transform(make_tuple(
|
||||
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_desc_m_k;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<LaneGroups>{}, // m1
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KPack>{}, // lds load vector
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
|
||||
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_desc_m_k;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc()
|
||||
{
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
|
||||
|
||||
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
|
||||
constexpr index_t KPad = 0; // pad between warps
|
||||
|
||||
constexpr auto desc =
|
||||
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
make_tuple(number<Block_N + KPad>{}, number<1>{}),
|
||||
number<KVector>{},
|
||||
number<1>{});
|
||||
return desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc()
|
||||
{
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
|
||||
|
||||
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
|
||||
constexpr index_t KPad = 0; // KVector; // pad between warps
|
||||
|
||||
constexpr auto desc =
|
||||
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
make_tuple(number<Block_N + KPad>{}, number<1>{}),
|
||||
number<KVector>{},
|
||||
number<1>{});
|
||||
return desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreForUKDesc()
|
||||
{
|
||||
constexpr index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0;
|
||||
constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N0;
|
||||
constexpr index_t Repeat_M = Problem::BlockShape::Repeat_M0;
|
||||
|
||||
constexpr index_t kAMLane = 16;
|
||||
constexpr index_t kABKLane = 4;
|
||||
constexpr index_t kABKPerLane = 4;
|
||||
|
||||
constexpr index_t KPack = kABKPerLane;
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<Repeat_M>{}, // m
|
||||
number<Repeat_N>{}, // n
|
||||
number<WarpPerBlock_N>{}, // n
|
||||
number<kABKLane>{}, // n
|
||||
number<kAMLane>{}, // m
|
||||
number<KPack>{}), // n
|
||||
make_tuple(number<Repeat_N * WarpPerBlock_N * kABKLane * kAMLane * KPack>{}, // m
|
||||
number<WarpPerBlock_N * kABKLane * kAMLane * KPack>{}, // n
|
||||
number<kABKLane * kAMLane * KPack>{}, // n
|
||||
number<kAMLane * KPack>{}, // n
|
||||
number<KPack>{}, // m
|
||||
number<1>{}), // n
|
||||
number<KPack>{}, // lds store vector(actually no explicit store)
|
||||
number<1>{});
|
||||
|
||||
constexpr auto desc = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
|
||||
make_merge_transform(make_tuple(number<Repeat_N>{},
|
||||
number<WarpPerBlock_N>{},
|
||||
number<kABKLane>{},
|
||||
number<KPack>{}))),
|
||||
make_tuple(sequence<0, 4>{}, sequence<1, 2, 3, 5>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
|
||||
{
|
||||
using S_ = typename Problem::BlockShape;
|
||||
// A is vgpr, B is agpr. But since we transposed, so also need swap this
|
||||
// TODO: this is ugly
|
||||
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
|
||||
// TODO: ugly
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
|
||||
2>>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
|
||||
std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
|
||||
2>>{};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0()
|
||||
{
|
||||
// this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
|
||||
// the purpose is to hide thoes instructions under mfma
|
||||
// every value inside seq<...> is a mask, indicating a specific operation
|
||||
using S_ = typename Problem::BlockShape;
|
||||
constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
|
||||
constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
|
||||
constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
|
||||
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
|
||||
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
|
||||
S_::Block_N1 == 128)
|
||||
{
|
||||
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
|
||||
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
|
||||
// clang-format off
|
||||
constexpr auto seq_all =
|
||||
// 0 1 2 3 4 5 6 7
|
||||
sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 0
|
||||
GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 1
|
||||
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 2
|
||||
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 3
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
|
||||
return seq_all;
|
||||
// clang-format on
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
|
||||
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
|
||||
S_::Block_N1 == 128)
|
||||
{
|
||||
// Total 32 instructions, 16 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
|
||||
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
|
||||
// clang-format off
|
||||
constexpr auto seq_all =
|
||||
// 0 1 2 3 4 5 6 7
|
||||
sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 0
|
||||
GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 1
|
||||
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 2
|
||||
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A>{}; // 3
|
||||
return seq_all;
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1()
|
||||
{
|
||||
// this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
|
||||
// the purpose is to hide thoes instructions under mfma
|
||||
// every value inside seq<...> is a mask, indicating a specific operation
|
||||
using S_ = typename Problem::BlockShape;
|
||||
constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
|
||||
constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
|
||||
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
|
||||
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
|
||||
S_::Block_N1 == 128)
|
||||
{
|
||||
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
|
||||
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
|
||||
// clang-format off
|
||||
constexpr auto seq_all =
|
||||
// 0 1 2 3 4 5 6 7
|
||||
sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 0
|
||||
GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 3
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
|
||||
return seq_all;
|
||||
// clang-format on
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
|
||||
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
|
||||
S_::Block_N1 == 128)
|
||||
{
|
||||
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
|
||||
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
|
||||
// clang-format off
|
||||
constexpr auto seq_all =
|
||||
// 0 1 2 3 4 5 6 7
|
||||
sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 0
|
||||
GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
|
||||
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 3
|
||||
return seq_all;
|
||||
// clang-format on
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
|
||||
{
|
||||
using S_ = typename Problem::BlockShape;
|
||||
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
|
||||
// TODO: ugly
|
||||
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
|
||||
2>>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
|
||||
2>>{};
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0()
|
||||
{
|
||||
using S_ = remove_cvref_t<typename Problem::BlockShape>;
|
||||
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
|
||||
using CDataType = typename WarpGemm::CDataType;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<S_::Repeat_M0, S_::WarpPerBlock_M0>,
|
||||
sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1()
|
||||
{
|
||||
using S_ = remove_cvref_t<typename Problem::BlockShape>;
|
||||
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
|
||||
using CDataType = typename WarpGemm::CDataType;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
|
||||
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// this is used as A matrix for 2nd gemm
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
|
||||
{
|
||||
using S_ = remove_cvref_t<typename Problem::BlockShape>;
|
||||
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
|
||||
|
||||
// TODO: all waves a along different N, but same M
|
||||
constexpr auto y_outer_dstr_enc =
|
||||
tile_distribution_encoding<sequence<S_::WarpPerBlock_M1>,
|
||||
tuple<sequence<S_::Repeat_M1>, sequence<S_::Repeat_K1>>,
|
||||
tuple<sequence<0>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
|
||||
constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
|
||||
return y_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile()
|
||||
{
|
||||
constexpr auto y_block_dstr = MakeYTileDistribution<Problem>();
|
||||
auto y_block_tensor =
|
||||
make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
|
||||
return y_block_tensor;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetUK_0()
|
||||
{
|
||||
using S_ = typename Problem::BlockShape;
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
|
||||
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
|
||||
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
|
||||
{
|
||||
return Flatmm_ff_32x512x128_1x4x1_16x16x32_BF16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
|
||||
std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
|
||||
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
|
||||
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
|
||||
{
|
||||
return Flatmm_ff_32x512x128_1x4x1_16x16x32_FP16{};
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user