mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
add fp16 to test
This commit is contained in:
@@ -13,6 +13,22 @@
|
||||
// Type bundle for the fused-MoE GEMM example, selected by seven type parameters.
// The primary template is declared but not defined here; the partial
// specializations below supply the actual aliases.
//   I, W, O : presumably the input / weight / output element types
//             (the specializations pin all three to a single 16-bit float type)
//   ST, SW  : forwarded to the A-scale and the G/D-scale aliases
//   SQ      : forwarded to the Y smooth-scale alias
//   KW      : forwarded to the top-k weight alias
template <typename I,
          typename W,
          typename O,
          typename ST,
          typename SW,
          typename SQ,
          typename KW>
struct FusedMoeGemmTypeConfig;
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::fp16_t;
|
||||
using GDataType = ck_tile::fp16_t;
|
||||
using DDataType = ck_tile::fp16_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::fp16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
|
||||
@@ -19,7 +19,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
|
||||
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 32, 128>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
@@ -19,8 +19,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0,
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
typename Ts_::WarpPerBlock_1,
|
||||
typename Ts_::WarpTile_1>;
|
||||
using f_problem =
|
||||
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
|
||||
typename Ts_::GDataType,
|
||||
|
||||
@@ -49,7 +49,7 @@ struct fmoe_ // traits, ugly name, only used for internal
|
||||
;
|
||||
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
|
||||
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpPerBlock_1 = ck_tile::sequence<1, 1, 4>;//ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr ck_tile::index_t GateOnly = GateOnly_;
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 32, 128>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -252,8 +252,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistribution<TopkWeightDataType>{0.0f, 1.0f}(topk_weight_host);
|
||||
|
||||
// permute weight
|
||||
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
// ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
// ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
|
||||
// do moe sorting
|
||||
if(balance)
|
||||
@@ -287,7 +287,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// std::cout << num_sorted_tiles_host << std::endl;
|
||||
// output_matrix_3d(g_host, experts, shared_intermediate_size_0, hidden_size);
|
||||
std::cout << sorted_expert_ids_host << std::endl;
|
||||
// std::cout << topk_weight_host << std::endl;
|
||||
std::cout << topk_weight_host << std::endl;
|
||||
|
||||
// std::cout << sorted_weight_host << std::endl;
|
||||
|
||||
@@ -431,7 +431,7 @@ int main(int argc, char* argv[])
|
||||
// no dynamic quant case
|
||||
if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
|
||||
return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>(
|
||||
arg_parser)
|
||||
? 0
|
||||
: -2;
|
||||
|
||||
@@ -88,6 +88,32 @@ struct FusedMoeGemmPipeline_General
|
||||
const auto a_coord = a_dist.calculate_index();
|
||||
return a_coord;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE static void PrintMem(T& tensor)
|
||||
{
|
||||
constexpr auto spans = T::get_distributed_spans();
|
||||
int counter = 0;
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idxn) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idxk) {
|
||||
constexpr auto i_j_idx = make_tuple(idxn, idxk);
|
||||
const auto tile_idx =
|
||||
get_x_indices_from_distributed_indices(tensor.get_tile_distribution(), i_j_idx);
|
||||
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
{
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = tile_idx.at(number<1>{});
|
||||
printf("in G row is %d , col is %d, counter is %d, value is: %f"
|
||||
" \n",
|
||||
row,
|
||||
col,
|
||||
counter,
|
||||
ck_tile::type_convert<float>(tensor(i_j_idx)));
|
||||
counter = counter + 1;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
|
||||
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
|
||||
const GWindow& g_window_,
|
||||
@@ -131,56 +157,13 @@ struct FusedMoeGemmPipeline_General
|
||||
auto a_dram_block = load_tile(a_global_to_dram_window);
|
||||
store_tile(a_lds_win, a_dram_block);
|
||||
#if 0
|
||||
{
|
||||
// check a matrix gather right or not
|
||||
constexpr auto a_spans = decltype(a_dram_block)::get_distributed_spans();
|
||||
int counter = 0;
|
||||
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
|
||||
sweep_tile_span(a_spans[number<1>{}], [&](auto idxk) {
|
||||
constexpr auto i_j_idx = make_tuple(idxm, idxk);
|
||||
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
{
|
||||
counter = counter + 1;
|
||||
index_t idm_0 = idxm.impl_.at(0);
|
||||
index_t idk_0 = idxk.impl_.at(0);
|
||||
printf("in A idm is %d , idk_ is %d , counter is %d, value is: %f \n",
|
||||
idm_0,
|
||||
idk_0,
|
||||
counter,
|
||||
ck_tile::type_convert<float>(a_dram_block(i_j_idx)));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
PrintMem(a_dram_block);
|
||||
#endif
|
||||
|
||||
auto g_dram_block = load_tile(g_global_to_dram_window);
|
||||
|
||||
#if 0
|
||||
{
|
||||
constexpr auto g_spans = decltype(g_dram_block)::get_distributed_spans();
|
||||
int counter = 0;
|
||||
sweep_tile_span(g_spans[number<0>{}], [&](auto idxn) {
|
||||
sweep_tile_span(g_spans[number<1>{}], [&](auto idxk) {
|
||||
constexpr auto i_j_idx = make_tuple(idxn, idxk);
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
g_dram_block.get_tile_distribution(), i_j_idx);
|
||||
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
{
|
||||
counter = counter + 1;
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = tile_idx.at(number<1>{});
|
||||
printf("in G row is %d , col is %d, counter is %d, value is: %f"
|
||||
" \n",
|
||||
row,
|
||||
col,
|
||||
counter,
|
||||
ck_tile::type_convert<float>(g_dram_block(i_j_idx)));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
PrintMem(g_dram_block);
|
||||
#endif
|
||||
|
||||
clear_tile(s_acc); // initialize C
|
||||
@@ -215,32 +198,8 @@ struct FusedMoeGemmPipeline_General
|
||||
// activation(s_acc.get_thread_buffer()(i),s_acc.get_thread_buffer()[i]);
|
||||
// });
|
||||
tile_elementwise_inout(activation, s_acc, s_acc);
|
||||
#if 1
|
||||
{
|
||||
constexpr auto a_spans = decltype(s_acc)::get_distributed_spans();
|
||||
int counter = 0;
|
||||
// a_spans[0] = 1;
|
||||
sweep_tile_span(a_spans[number<0>{}], [&](auto idxm) {
|
||||
sweep_tile_span(a_spans[number<1>{}], [&](auto idxn) {
|
||||
constexpr auto i_j_idx = make_tuple(idxm, idxn);
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
g_dram_block.get_tile_distribution(), i_j_idx);
|
||||
if(threadIdx.x == 1 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
{
|
||||
counter = counter + 1;
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = tile_idx.at(number<1>{});
|
||||
printf("in c row is %d , col is %d, counter is %d, value is: "
|
||||
"%f \n",
|
||||
row,
|
||||
col,
|
||||
counter,
|
||||
ck_tile::type_convert<float>(s_acc(i_j_idx)));
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#if 0
|
||||
PrintMem(s_acc);
|
||||
#endif
|
||||
// move sacc to LDS
|
||||
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
|
||||
@@ -249,15 +208,30 @@ struct FusedMoeGemmPipeline_General
|
||||
make_tile_window(bridge_lds_view,
|
||||
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
|
||||
{0, 0});
|
||||
|
||||
// cast data to YDataType
|
||||
auto y_pre = cast_tile<YDataType>(s_acc);
|
||||
// constexpr index_t thread_buffer_size = SaccBlockTileType::get_thread_buffer_size();
|
||||
// static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
// //y_pre.get_thread_buffer()(i) = type_convert<YDataType>(s_acc.get_thread_buffer()[i]);
|
||||
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
// {
|
||||
// printf("soure value: %f to value: %f\n",
|
||||
// s_acc.get_thread_buffer()[i],
|
||||
// type_convert<float>(y_pre.get_thread_buffer()[i]));
|
||||
// }
|
||||
// });
|
||||
|
||||
#if 1
|
||||
PrintMem(y_pre);
|
||||
#endif
|
||||
// save to lds
|
||||
store_tile(bridge_slds_win, y_pre);
|
||||
block_sync_lds();
|
||||
|
||||
// gemm down
|
||||
constexpr auto gemm_1 = Policy::template GetBlockGemm1<Problem>();
|
||||
using SaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
auto o_acc = SaccBlockTileType{};
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
// y data
|
||||
auto bridge_llds_win =
|
||||
make_tile_window(bridge_lds_view,
|
||||
@@ -265,6 +239,7 @@ struct FusedMoeGemmPipeline_General
|
||||
{0, 0},
|
||||
Policy::template MakeYTileDistribution<Problem>());
|
||||
auto y = load_tile(bridge_llds_win);
|
||||
|
||||
// d data
|
||||
auto d_global_to_dram_window = make_tile_window(
|
||||
d_window_.get_bottom_tensor_view(),
|
||||
@@ -278,6 +253,7 @@ struct FusedMoeGemmPipeline_General
|
||||
index_t iCounter1 = n1_loops - 1;
|
||||
while(iCounter1 > 0)
|
||||
{
|
||||
clear_tile(o_acc);
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc, y, d);
|
||||
block_sync_lds();
|
||||
@@ -292,9 +268,16 @@ struct FusedMoeGemmPipeline_General
|
||||
}
|
||||
// tail
|
||||
{
|
||||
clear_tile(o_acc);
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc, y, d);
|
||||
|
||||
auto o = cast_tile<ODataType>(o_acc);
|
||||
store_tile(o_window_, o);
|
||||
}
|
||||
#if 0
|
||||
PrintMem(o_acc);
|
||||
#endif
|
||||
// store_tile(o_window_, a_dram_block);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -230,7 +230,28 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
typename S_::WarpPerBlock_1,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// this is used as A matrix for 2nd gemm
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
|
||||
{
|
||||
using S_ = remove_cvref_t<typename Problem::BlockShape>;
|
||||
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
|
||||
|
||||
constexpr auto y_outer_dstr_enc = tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>, sequence<S_::WarpPerBlock_K1, S_::Repeat_K1>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{};
|
||||
|
||||
constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
|
||||
constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
|
||||
return y_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -240,12 +261,12 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
|
||||
|
||||
constexpr auto d_outer_dstr_enc = tile_distribution_encoding<
|
||||
sequence<S_::WarpPerBlock_M1>,
|
||||
tuple<sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>, sequence<S_::Repeat_K1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1>,
|
||||
tuple<sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>, sequence<S_::WarpPerBlock_K1, S_::Repeat_K1>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
sequence<0, 1>>{};
|
||||
|
||||
constexpr auto d_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
d_outer_dstr_enc, typename WarpGemm::BWarpDstrEncoding{});
|
||||
@@ -326,7 +347,15 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
// TODO: this is ugly
|
||||
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
|
||||
// TODO: ugly
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
|
||||
std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<wg_ctrl>,
|
||||
1>>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
|
||||
{
|
||||
@@ -358,7 +387,15 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
using S_ = typename Problem::BlockShape;
|
||||
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
|
||||
// TODO: ugly
|
||||
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
|
||||
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
|
||||
{
|
||||
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<wg_ctrl>,
|
||||
1>>{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
|
||||
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
|
||||
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
|
||||
{
|
||||
@@ -383,27 +420,5 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
2>>{};
|
||||
}
|
||||
}
|
||||
|
||||
// this is used as A matrix for 2nd gemm
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
|
||||
{
|
||||
using S_ = remove_cvref_t<typename Problem::BlockShape>;
|
||||
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
|
||||
|
||||
// TODO: all waves a along different N, but same M
|
||||
constexpr auto y_outer_dstr_enc = tile_distribution_encoding<
|
||||
sequence<S_::WarpPerBlock_N1>,
|
||||
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>, sequence<S_::Repeat_K1>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
|
||||
constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
|
||||
return y_block_dstr;
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
202
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
Normal file
202
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
Normal file
@@ -0,0 +1,202 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block distributed tensor
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
|
||||
struct BlockGemmARegBRegCRegV2
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
// M->N Warp
|
||||
// constexpr auto a_block_outer_dstr_encoding =
|
||||
// tile_distribution_encoding<sequence<NWarp>,
|
||||
// tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
// tuple<sequence<1, 0>>,
|
||||
// tuple<sequence<1, 0>>,
|
||||
// sequence<1, 2>,
|
||||
// sequence<0, 0>>{};
|
||||
|
||||
// constexpr auto b_block_outer_dstr_encoding =
|
||||
// tile_distribution_encoding<sequence<MWarp>,
|
||||
// tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
// tuple<sequence<0, 1>>,
|
||||
// tuple<sequence<0, 1>>,
|
||||
// sequence<1, 2>,
|
||||
// sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
// constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
// a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
// constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
// b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
// check ABC-block-distribution
|
||||
// static_assert(
|
||||
// std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
|
||||
// remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
|
||||
// .get_static_tile_distribution_encoding())>>,
|
||||
// "A distribution is wrong!");
|
||||
// static_assert(
|
||||
// std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
|
||||
// remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
|
||||
// .get_static_tile_distribution_encoding())>>,
|
||||
// "B distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"C distribution is wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using BWarpDstr = typename WG::BWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using BWarpTensor = typename WG::BWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensor, typename BBlockTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user