[CK_TILE] fused-moe first version (#1634)

* moe pipeline

* update code

* compile OK

* update

* update cpu reference

* update pipeline_gemm0

* compiler ok

* update pipeline

* rename to ex pipeline

* block-asm

* update

* update

* update first gemm ok

* compute correct

* update file structure

* update README

* update

* update

* update code

* update API

* return unsupport case

* add comment

* update readme

* update

* uncomment

* update

* fix build err

---------

Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
carlushuang
2024-11-26 11:14:56 +08:00
committed by GitHub
parent 645fe812f6
commit 440e28b08f
66 changed files with 8066 additions and 308 deletions

View File

@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;

View File

@@ -42,8 +42,8 @@ enum class matrix_core_permute_style
{
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
};
// assume this is B matrix, originally we have batch*n*k
@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel
else
{
// clang-format off
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
// b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten
constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel
make_tuple(sequence<0>{}, sequence<1>{}));
return tmp_1;
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel
else
{
#if MERGE_2D_013425
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
return make_tile_window(dst_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{i_n * NPerBlock, i_k * KPerBlock},
get_dst_dist());
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;

View File

@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
{
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
// b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
matrix_core_swizzle_traits t;
t.data_type = data_type;
t.permute = arg_parser.get_str("perm");

View File

@@ -5,7 +5,7 @@
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/moe_sorting.hpp"
#include "ck_tile/ops/fused_moe.hpp"
struct moe_sorting_trait
{

View File

@@ -0,0 +1,19 @@
set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding ${TILE_EXAPMLE_FUSED_MOE}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp)
target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS})
set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS})

View File

@@ -0,0 +1,69 @@
# fused-moe
Implementing the fused-moe block operator using ck-tile. This is a scatter/gather-group-gemm based solution, similiar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance
![](misc/moe-0.png)
The benifit of this fused-moe:
* 1.5~2x perf boost compared with current vllm solution
* zero workspace to reduce memory footprint
* much less kernel instance, easy to maintain
# Implementation and feature support
## moe-sorting
this is a common pre-process step before the actual moe-gemm. The purpose is to transform the moe loop over from token-by-token to expert-by-expert, make sure very workgroup is working for a single expert (B matrix). Besides, we extend this op to do the zeroing of the output buffer(to be used for reduce buffer with atomic)
## moe-gemm
`moe-gemm` is a group-gemm based back-to-back gemm, where the row-id of input token comes from another buffer. Naive understanding of fused-moe is from token-by-token view as below picture:
![](misc/moe-1.png)
After `moe-sorting`, we can view this algorithm as expert-by-expert, as below:
![](misc/moe-2.png)
## optimization
summary of the key design of this fused-moe operator:
* fuse 2 group-gemm + activation + `topk-weight` multiply into single kernel, using atomic for 2nd gemm accumualation
* fuse buffer-zeroing in `moe-sorgin`, user no longer need call extra torch.zero() for the out buffer
* fused scatter-gather for row index(same as vllm)
* pre-shuffle B matric(weight) to maximize memory throughput. input(activation) keep original layout `[batch, hidden]`.
* extrem optimized pipeline using block-inline-asm(we call it `micro-kernel` or `uk`), while not breaking the *composable* design of ck
##
```
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
```

View File

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "fused_moesorting.hpp"
#include "fused_moegemm.hpp"
struct fused_moe_args
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token (no need to do zeroing)
const void* topk_ids_ptr; // [tokens, topk]
const void* topk_weight_ptr; // [tokens, topk]
void* sorted_token_ids_ptr; // [max_num_tokens_padded]
void* sorted_weight_ptr; // [max_num_tokens_padded]
void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
void* num_sorted_tiles_ptr; // [1]
ck_tile::index_t block_m; // block_m, used to devide the input
ck_tile::index_t hidden_size; // k
ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
ck_tile::index_t num_tokens; // input number of tokens for current iteration
ck_tile::index_t num_experts; // number of groups
ck_tile::index_t topk; // need this?
ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
// This is the public API, will be generated by script
struct fused_moe_traits
{
std::string prec_i; // input precision
std::string prec_w; // weight precision
std::string prec_o; // output precision
std::string prec_st; // token scale data type
std::string prec_sw; // weight scale data type
std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type
int block_m;
int gate_only;
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);

View File

@@ -0,0 +1,84 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include <string>
// this is only a convenient structure for creating an example
// this is not part of the host API
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig;
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::bf16_t;
using GDataType = ck_tile::bf16_t;
using DDataType = ck_tile::bf16_t;
using AccDataType = float;
using ODataType = ck_tile::bf16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::fp16_t;
using GDataType = ck_tile::fp16_t;
using DDataType = ck_tile::fp16_t;
using AccDataType = float;
using ODataType = ck_tile::fp16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::int8_t;
using GDataType = ck_tile::int8_t;
using DDataType = ck_tile::int8_t;
using AccDataType = int32_t;
using ODataType = ck_tile::bf16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
// runtime args
struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs
{
};
// This is the public API, will be generated by script
struct fused_moegemm_traits
{
std::string prec_i; // input precision
std::string prec_w; // weight precision
std::string prec_o; // output precision
std::string prec_st; // token scale data type
std::string prec_sw; // weight scale data type
std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type
int block_m;
int gate_only;
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&);

View File

@@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/fused_moe.hpp"
struct fused_moesorting_trait
{
std::string index_type;
std::string weight_type; // currently always float
};
struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs
{
};
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s);

View File

@@ -0,0 +1,80 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fused_moe.hpp"
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
{
auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};
auto o_data_bytes = [&]() {
if(t.prec_o == "fp32")
return 4;
else if(t.prec_o == "fp16" || t.prec_o == "bf16")
return 2;
else if(t.prec_o == "int8" || t.prec_o == "fp8")
return 1;
return 1;
}();
auto t0 = fused_moesorting_trait{"int32", "fp32"};
auto a0 = fused_moesorting_args{
a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights;
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
a.o_ptr, // void* p_moe_buf;
a.num_tokens, // index_t tokens;
a.block_m, // index_t unit_size;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
};
auto t1 = fused_moegemm_traits{t.prec_i,
t.prec_w,
t.prec_o,
t.prec_st,
t.prec_sw,
t.prec_sq,
t.prec_kw,
t.block_m,
t.gate_only,
t.fused_quant};
auto a1 = fused_moegemm_args{
a.a_ptr, // const void* a_ptr;
a.a_scale_ptr, // const void* a_scale_ptr;
a.g_ptr, // const void* g_ptr;
a.d_ptr, // const void* d_ptr;
a.g_scale_ptr, // const void* g_scale_ptr;
a.d_scale_ptr, // const void* d_scale_ptr;
a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr;
a.o_ptr, // void* o_ptr;
a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr;
a.sorted_weight_ptr, // const void* sorted_weight_ptr;
a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr;
a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr;
a.hidden_size, // index_t hidden_size;
a.intermediate_size, // index_t intermediate_size;
a.num_tokens, // index_t num_tokens;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
a.stride_token // index_t stride_token;
};
float r0 = -1;
float r1 = -1;
float r = ck_tile::launch_kernel(
s,
[=, &r0](const ck_tile::stream_config&) { r0 = fused_moesorting(t0, a0, s_sub); },
[=, &r1](const ck_tile::stream_config&) { r1 = fused_moegemm(t1, a1, s_sub); });
// keep unsupported case return negative
if(r0 < 0 || r1 < 0)
return -1;
return r;
}

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
// Note: this internal API only declare, not define here, otherwise will block `make -j`
template <typename Traits_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
{
// clang-format off
float r = -1;
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a);
}
// clang-format on
return r;
}

View File

@@ -0,0 +1,60 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include <iostream>
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
// do not the define of this tepmlate function inside the _api.cpp, otherwise will block make -j
template <typename Ts_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
{
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0,
typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0>;
using f_problem =
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
typename Ts_::GDataType,
typename Ts_::DDataType,
typename Ts_::AccDataType,
typename Ts_::ODataType,
typename Ts_::AScaleDataType,
typename Ts_::GScaleDataType,
typename Ts_::DScaleDataType,
typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType,
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
f_shape,
f_traits>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
const dim3 grids = f_kernel::GridSize(a);
constexpr dim3 blocks = f_kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
static int printed = 0;
auto kargs = f_kernel::MakeKargs(a);
if(s.log_level_ > 0 && printed == 0)
{
std::cout << ", " << f_kernel::GetName() << std::flush;
printed = 1;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename I,
typename W,
typename O,
typename ST,
typename SW,
typename SQ,
typename KW,
typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
typename WarpPerBlock_,
typename WarpTile_, // seq<*,*,*>, used to select mfma
ck_tile::index_t GateOnly_ = 0,
ck_tile::index_t FusedQuant_ = 0>
struct fmoe_ // traits, ugly name, only used for internal
{
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
using GDataType = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
using DDataType = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
using AccDataType = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
using ODataType = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
using AScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
using GScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
using DScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token
static constexpr ck_tile::index_t BI_ =
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
static constexpr ck_tile::index_t GateOnly = GateOnly_;
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
};

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "fused_moesorting.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{
if(t.weight_type == "fp32" && t.index_type == "int32")
{
if(a.num_experts > 127)
{
printf("lds size exceed, only support experts <127 \n");
return -1;
}
if(a.moe_buf_bytes % 16)
{
printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
return -1;
}
using index_t = ck_tile::index_t;
using ms_weight_type = float;
index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
switch(smem_io_unroll_num)
{
case(1): {
MOE_SORTING_DISPATCH(1);
}
case(2): {
MOE_SORTING_DISPATCH(2);
}
case(3): {
MOE_SORTING_DISPATCH(3);
}
case(5): {
MOE_SORTING_DISPATCH(5);
}
case(6): {
MOE_SORTING_DISPATCH(6);
}
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): {
MOE_SORTING_DISPATCH(8);
}
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): {
MOE_SORTING_DISPATCH(10);
}
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: {
MOE_SORTING_DISPATCH(4);
}
}
}
return -1;
}

View File

@@ -0,0 +1,603 @@
#include <algorithm>
#include <cstring>
#include <unordered_set>
#include <vector>
#include <set>
#include "ck_tile/host.hpp"
#include "fused_moe.hpp"
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
// mfma_type, 0:32x32, 1:16x16
// TODO: padding?
template <typename T>
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
{
assert(t.get_lengths().size() == 3);
int b_ = t.get_lengths()[0];
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[2];
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 32, 4, 8});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
return t;
}
template <typename IndexType>
void topid_unique_gen(
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
size_t total_size = topk * tokens;
std::srand(seed);
std::set<IndexType> unique_set;
IndexType current_v;
for(size_t i = 0; i < total_size; i++)
{
if(i % topk == 0)
{
unique_set.clear();
}
current_v = std::rand() % num_expert;
while(unique_set.find(current_v) != unique_set.end())
{
current_v = std::rand() % num_expert;
}
unique_set.insert(current_v);
host_tensor[i] = current_v;
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("t", "128", "num input tokens")
.insert("e", "32", "num of experts")
.insert("k", "5", "topk")
.insert("h", "8192", "hidden_size of this model")
.insert("i", "8192", "intermediate_size between 2 gemms of FFN")
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
.insert("bm", "32", "blocking factor for sorted tokens")
.insert("tp", "8", "tensor parallel size")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec_i", "bf16", "input precision")
.insert("prec_w", "bf16", "weight precision")
.insert("prec_o", "bf16", "output precision")
.insert("prec_st", "auto", "token scale data type. auto will set to fp32")
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert(
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
.insert("balance",
"0",
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
.insert("init",
"2",
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
"normalized(slow)")
.insert("seed", "11939", "seed used to do random")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type,
// SQ:smooth-quant-type, KW:topk-weight-type
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t tokens = arg_parser.get_int("t");
ck_tile::index_t experts = arg_parser.get_int("e");
ck_tile::index_t topk = arg_parser.get_int("k");
ck_tile::index_t hidden_size = arg_parser.get_int("h");
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
ck_tile::index_t stride = arg_parser.get_int("stride");
ck_tile::index_t block_m = arg_parser.get_int("bm");
if(stride < 0)
stride = hidden_size;
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq");
std::string prec_kw = arg_parser.get_str("prec_kw");
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int fused_quant = arg_parser.get_int("fquant");
int gate_only = arg_parser.get_int("gate_only");
int api = arg_parser.get_int("api");
int balance = arg_parser.get_int("balance");
int tp = arg_parser.get_int("tp");
int init = arg_parser.get_int("init");
uint32_t seed = arg_parser.get_uint32("seed");
// w0 (Gate+Up or Gate only, N size)
ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
// w1 (Down, N size)
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_w)
base_str += "x" + prec_w;
if(prec_i != prec_o)
base_str += "=" + prec_o;
if(fused_quant != 0)
{
base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
}
return base_str;
}();
auto api_str = [&]() {
if(api == 0)
return std::string("fmoe");
else if(api == 1)
return std::string("moeg");
else if(api == 2)
return std::string("moes");
return std::string("");
}();
auto stride_str = [&]() {
if(stride == hidden_size)
return std::string("");
else
return std::string(", st:") + std::to_string(stride);
}();
std::cout << "[" << api_str << "|" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType;
using GDataType = typename TypeConfig::GDataType;
using DDataType = typename TypeConfig::DDataType;
using AccDataType = typename TypeConfig::AccDataType;
using ODataType = typename TypeConfig::ODataType;
using AScaleDataType = typename TypeConfig::AScaleDataType;
using GScaleDataType = typename TypeConfig::GScaleDataType;
using DScaleDataType = typename TypeConfig::DScaleDataType;
using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
using TopkWeightDataType = typename TypeConfig::TopkWeightDataType;
using IndexDataType = typename TypeConfig::IndexDataType;
// host verify
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
ck_tile::HostTensor<DScaleDataType> sd_host({shared_intermediate_size_1});
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
{(max_num_tokens_padded + block_m - 1) / block_m});
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
if(init == 0)
{
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
ck_tile::FillStepRange<AScaleDataType>{0.f, 1.f, 0.01f}(sa_host);
ck_tile::FillStepRange<GScaleDataType>{0.f, 1.f, 0.01f}(sg_host);
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
}
else if(init == 1)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host);
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed, true}(
topk_weight_host);
}
else if(init == 2)
{
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed, true}(a_host);
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed, true}(g_host);
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed, true}(sa_host);
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed, true}(sg_host);
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
}
// permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// do moe sorting
if(balance)
{
int e_cnt = 0;
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
{
topk_ids_host.mData[i] = e_cnt;
e_cnt++;
if(e_cnt >= experts)
e_cnt = 0;
}
}
else
{
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
}
// leave it here for future debug purpose
#if 0
a_host.loadtxt("../../ater/input_torch.txt");
topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int");
// topk_ids_host.savetxt("topk_ids_2.txt");
topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
g_host.loadtxt("../../ater/w1_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
d_host.loadtxt("../../ater/w2_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
#endif
#if 0
std::cout << "sorted_token_ids_host:" << sorted_token_ids_host << std::endl;
std::cout << "num_sorted_tiles_host:" << num_sorted_tiles_host << std::endl;
std::cout << "sorted_expert_ids_host:" << sorted_expert_ids_host << std::endl;
std::cout << "topk_weight_host:" << topk_weight_host << std::endl;
std::cout << "sorted_weight_host:" << sorted_weight_host << std::endl;
#endif
auto cal_tflops = [&](auto ms) {
double flop_gemm_0 =
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_0 * hidden_size;
double flop_gemm_1 =
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_1 * hidden_size;
return (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ms) * 1e-3) / 1e12;
};
// TODO: this method we use expert-by-expert view, just for reference
auto cal_tbps = [&](auto ms) {
double token_bytes =
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ADataType);
double w0_bytes = static_cast<double>(shared_intermediate_size_0) * experts * hidden_size *
sizeof(GDataType);
double w1_bytes = static_cast<double>(shared_intermediate_size_1) * experts * hidden_size *
sizeof(DDataType);
double o_bytes =
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ODataType);
double topk_weights_bytes = static_cast<double>(tokens) * topk * sizeof(TopkWeightDataType);
// ignore index, they are too small
return (token_bytes + w0_bytes + w1_bytes + o_bytes + topk_weights_bytes) /
(static_cast<double>(ms) * 1e-3) / 1e12;
};
if(api == 0)
{
ck_tile::DeviceMem a_buf(a_host);
ck_tile::DeviceMem g_perm_buf(g_perm_host);
ck_tile::DeviceMem d_perm_buf(d_perm_host);
ck_tile::DeviceMem sa_buf(sa_host);
ck_tile::DeviceMem sg_buf(sg_host);
ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem topk_ids_buf(topk_ids_host);
ck_tile::DeviceMem topk_weight_buf(topk_weight_host);
ck_tile::DeviceMem sorted_token_ids_buf(
sorted_token_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_expert_ids_buf(
sorted_expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem num_sorted_tiles_buf(
num_sorted_tiles_host.get_element_space_size_in_bytes());
fused_moe_traits traits{prec_i,
prec_w,
prec_o,
prec_st,
prec_sw,
prec_sq,
prec_kw,
block_m,
gate_only,
fused_quant};
fused_moe_args args{a_buf.GetDeviceBuffer(),
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
g_perm_buf.GetDeviceBuffer(),
d_perm_buf.GetDeviceBuffer(),
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
o_buf.GetDeviceBuffer(),
topk_ids_buf.GetDeviceBuffer(),
topk_weight_buf.GetDeviceBuffer(),
sorted_token_ids_buf.GetDeviceBuffer(),
sorted_weight_buf.GetDeviceBuffer(),
sorted_expert_ids_buf.GetDeviceBuffer(),
num_sorted_tiles_buf.GetDeviceBuffer(),
block_m,
hidden_size,
shared_intermediate_size_0,
tokens,
experts,
topk,
stride};
float ave_time = fused_moe(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
// float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
<< cal_tbps(ave_time) << " TB/s" << std::flush;
bool pass = true;
if(do_validation)
{
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host,
topk_weight_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host.mData[0],
experts,
block_m);
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
a_host,
g_host,
d_host,
sa_host,
sg_host,
sd_host,
sy_host,
o_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host,
topk_ids_host,
block_m,
tokens,
experts,
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float");
auto [rtol, atol] = get_elimit<ADataType>();
pass &= ck_tile::check_err(
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
std::cout << std::flush << std::endl;
return pass;
}
else if(api == 1)
{
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host,
topk_weight_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host.mData[0],
experts,
block_m);
// done, preparing GPU buffer
ck_tile::DeviceMem a_buf(a_host);
ck_tile::DeviceMem g_perm_buf(g_perm_host);
ck_tile::DeviceMem d_perm_buf(d_perm_host);
ck_tile::DeviceMem sa_buf(sa_host);
ck_tile::DeviceMem sg_buf(sg_host);
ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem o_buf(o_host);
// manually clear output buffer for atomic
o_buf.SetZero();
//
ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
fused_moegemm_traits traits{prec_i,
prec_w,
prec_o,
prec_st,
prec_sw,
prec_sq,
prec_kw,
block_m,
gate_only,
fused_quant};
fused_moegemm_args args{a_buf.GetDeviceBuffer(),
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
g_perm_buf.GetDeviceBuffer(),
d_perm_buf.GetDeviceBuffer(),
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
o_buf.GetDeviceBuffer(),
sorted_token_ids_buf.GetDeviceBuffer(),
sorted_weight_buf.GetDeviceBuffer(),
sorted_expert_ids_buf.GetDeviceBuffer(),
num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size,
shared_intermediate_size_0,
tokens,
experts,
topk,
stride};
float ave_time = fused_moegemm(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
// float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
<< cal_tbps(ave_time) << " TB/s" << std::flush;
bool pass = true;
if(do_validation)
{
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
a_host,
g_host,
d_host,
sa_host,
sg_host,
sd_host,
sy_host,
o_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host,
topk_ids_host,
block_m,
tokens,
experts,
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float");
auto [rtol, atol] = get_elimit<ADataType>();
pass &= ck_tile::check_err(
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
std::cout << std::flush << std::endl;
return pass;
}
return false;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq");
std::string prec_kw = arg_parser.get_str("prec_kw");
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
// no dynamic quant case
if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
arg_parser)
? 0
: -2;
}
else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32")
{
return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>(
arg_parser)
? 0
: -2;
}
return -3;
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 90 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

View File

@@ -14,3 +14,5 @@ add_subdirectory(11_add_rmsnorm2d_rdquant)
add_subdirectory(12_smoothquant)
add_subdirectory(13_moe_sorting)
add_subdirectory(14_moe_smoothquant)
add_subdirectory(15_fused_moe)

View File

@@ -52,6 +52,7 @@
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
@@ -62,6 +63,7 @@
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

View File

@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
}
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add_if;
template <bool pre_nop>
struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(v_offset),
"v"(bit_cast<mbuf_t>(value)),
"s"(res.xy),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add;
template <bool pre_nop>
struct buffer_atomic_add<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag = 1*/)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
:
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
: "memory");
}
};
namespace impl {
// below type indicate the data type used for buffer load inline asm
// clang-format off
@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
// buffer load i8
CK_TILE_DEVICE_EXTERN int8_t
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
#endif
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const index_t dst_linear_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size,
bool_constant<pre_nop> = {})
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
if constexpr(oob_conditional_check)
{
buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
dst_thread_element_valid);
}
else
{
buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
1);
}
}
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.

View File

@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
#endif
}
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
asm volatile("s_wait_loadcnt %0 \n"
"s_barrier_signal -1 \n"
"s_barrier_wait -1"
:
: "n"(cnt)
: "memory");
#else
asm volatile("s_waitcnt vmcnt(%0) \n"
"s_barrier"
:
: "n"(cnt)
: "memory");
#endif
}
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
asm volatile("\

View File

@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
#endif
}
template <typename T>
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
{
static_assert(sizeof(T) == 4);
// per-thread v_flag store into 2x sgpr
uint32x2_t exec_flag;
asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
: [s_exec_flag] "=s"(exec_flag)
: [v_flag] "v"(v_flag));
return exec_flag;
}
template <typename X, typename Y>
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
{
static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
// per-thread cmp store into 2x sgpr
uint32x2_t exec_flag;
asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
: [s_exec_flag] "=s"(exec_flag)
: [v_x] "v"(x), [v_y] "v"(y));
return exec_flag;
}
} // namespace ck_tile

View File

@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
CK_TILE_DEVICE void update(index_t i,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, linear_offset, is_valid_element, x);
this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
this->template atomic_add<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_max)
{
this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
this->template atomic_max<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
auto tmp =
this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
this->template set<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x + tmp);
// tmp += x;
// this->template set<X>(i, is_valid_element, tmp);
}
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update_raw(index_t i,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_max)
{
// this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
}
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
@@ -585,6 +626,39 @@ struct buffer_view<address_space_enum::global,
}
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void
atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
// using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add_raw<remove_cvref_t<T>,
t_per_x,
Coherence,
oob_conditional_check,
pre_nop>(
x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,

View File

@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
@@ -51,15 +55,35 @@ template <typename DistributedTensor_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
/**
@@ -76,6 +100,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
@@ -83,11 +108,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename T,
@@ -95,6 +121,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
@@ -102,11 +129,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
@@ -114,6 +142,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto
@@ -122,11 +151,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
@@ -134,6 +166,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
@@ -141,11 +174,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)

View File

@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
return unpacks;
}
namespace detail {
// check if 2 static_distributed_tensor has same data type and size of element
// but only difference in distribution
template <typename X, typename Y>
struct is_similiar_distributed_tensor
{
static constexpr bool value = false;
};
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
static_distributed_tensor<TypeY, DistY>>
{
using Tx = static_distributed_tensor<TypeX, DistX>;
using Ty = static_distributed_tensor<TypeY, DistY>;
static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
};
template <typename X, typename Y>
inline constexpr bool is_similiar_distributed_tensor_v =
is_similiar_distributed_tensor<X, Y>::value;
} // namespace detail
} // namespace ck_tile

View File

@@ -333,6 +333,48 @@ struct tensor_view
coord.get_offset(), linear_offset, is_valid_element, x);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_view{");

View File

@@ -292,12 +292,15 @@ struct tile_window_with_static_distribution
{
constexpr auto tile_dstr = TileDstr{};
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
load(dst_tensor, bool_constant<oob_conditional_check>{});
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
return dst_tensor;
}
template <typename DistributedTensor, bool oob_conditional_check = true>
template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
@@ -785,6 +788,73 @@ struct tile_window_with_static_distribution
});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// move thread's botom tensor coordiante
// [x0', x1', ... ] ==> [offset]
// also move window-origin

View File

@@ -432,23 +432,38 @@ struct tile_window_linear
CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
{
constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
// since this is linear offset, we assum bottom X tensor is always linear
constexpr index_t linear_offset = [&]() {
constexpr auto x_idx_ = linear_coord;
constexpr auto x_len_ = TileDstr{}.get_lengths();
static_assert(x_idx_.size() == x_len_.size());
constexpr index_t x_dims_ = x_idx_.size();
index_t cu_stride_ = 1;
index_t cu_offset_ = 0;
static_for<0, x_dims_, 1>{}([&](auto i_) {
auto r_i_ = number<x_dims_ - i_ - 1>{};
cu_offset_ += x_idx_[r_i_] * cu_stride_;
cu_stride_ *= x_len_[r_i_];
});
return cu_offset_;
}();
return linear_offset;
constexpr auto is_pure_linear_tensor =
reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
if constexpr(is_pure_linear_tensor)
{
// this case usually is a LDS window, everything is known at compile tile.
// we directly use BottomTensorView transform to compute the offset, in case padding
auto bottom_tensor_coord =
make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
return bottom_tensor_coord.get_offset();
}
else
{
// this case usually is a global window, where last dim can be linear
// we hack here, that use the original TileDstr to compute the linear offset
// ... hoping that there is no extra padding between other dims, which make sense
// since that would introduce runtime length (so can't use linear offset)
constexpr index_t linear_offset = [&]() {
constexpr auto x_idx_ = linear_coord;
constexpr auto x_len_ = TileDstr{}.get_lengths();
static_assert(x_idx_.size() == x_len_.size());
constexpr index_t x_dims_ = x_idx_.size();
index_t cu_stride_ = 1;
index_t cu_offset_ = 0;
static_for<0, x_dims_, 1>{}([&](auto i_) {
auto r_i_ = number<x_dims_ - i_ - 1>{};
cu_offset_ += x_idx_[r_i_] * cu_stride_;
cu_stride_ *= x_len_[r_i_];
});
return cu_offset_;
}();
return linear_offset;
}
}
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
@@ -509,6 +524,64 @@ struct tile_window_linear
return dst_tensor;
}
template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DstTile& dst_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
auto bottom_tensor_flag = cached_flags_[IAccess];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
bool_constant<oob_conditional_check>{});
#if 1
// data index [y0, y1, ...]
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
// write into distributed tensor
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
};
WINDOW_DISPATCH_ISSUE();
return dst_tensor;
}
template <typename DstTile,
index_t i_access = -1,
bool oob_conditional_check = true,
@@ -849,6 +922,58 @@ struct tile_window_linear
WINDOW_DISPATCH_ISSUE();
}
template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
};
WINDOW_DISPATCH_ISSUE();
}
// move thread's botom tensor coordiante
// [x0', x1', ... ] ==> [offset]
// also move window-origin

View File

@@ -0,0 +1,54 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#pragma once
namespace ck_tile {
// input a lds store tile, extract some information from it
// used to set m0 value for gfx9 serious
template <typename LdsTileWindow_>
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
// issues * warps * lanes
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
const index_t size_per_buf =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType);
const index_t size_per_wave =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t size_per_issue =
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
sizeof(LdsDataType) -
size_per_buf;
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
return make_tuple(m0_init_value, size_per_issue);
}
} // namespace ck_tile

View File

@@ -41,15 +41,65 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
typename DataType_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
tile_window.update(dstr_tensor);
tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void
update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.update_raw(dstr_tensor,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
typename DataType_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto update_tile_raw(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.update_raw(dstr_tensor,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
} // namespace ck_tile

View File

@@ -0,0 +1,116 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace ck_tile {
template <typename Context, index_t Start = 0, index_t Step = 1>
struct static_counter
{
public:
template <typename Unique>
static constexpr index_t next()
{
return next<Unique>(0) * Step + Start;
}
template <unsigned long long>
static constexpr index_t next()
{
struct Unique
{
};
return next<Unique>(0) * Step + Start;
}
template <typename Unique>
static constexpr index_t current()
{
return current<Unique>(0) * Step + Start;
}
template <unsigned long long>
static constexpr index_t current()
{
struct Unique
{
};
return current<Unique>(0) * Step + Start;
}
private:
template <index_t I>
struct slot
{
_Pragma("GCC diagnostic push");
_Pragma("GCC diagnostic ignored \"-Wundefined-internal\"");
friend constexpr bool slot_allocated(slot<I>);
_Pragma("GCC diagnostic pop");
};
template <index_t I>
struct allocate_slot
{
friend constexpr bool slot_allocated(slot<I>) { return true; }
enum
{
value = I
};
};
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
// the overload set...
template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
static constexpr index_t next(index_t)
{
return next<Unique, I + 1>(0);
}
// ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
// allocate_slot<I>.
template <typename Unique, index_t I = 0>
static constexpr index_t next(double)
{
return allocate_slot<I>::value;
}
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
// the overload set...
template <typename Unique, index_t I = Start, bool = slot_allocated(slot<I>())>
static constexpr index_t current(index_t)
{
return current<Unique, I + 1>(0);
}
// ...And this function will be used, instead, which will return the current counter, or assert
// in case next() hasn't been called yet.
template <typename Unique, index_t I = Start>
static constexpr index_t current(double)
{
static_assert(I != 0, "You must invoke next() first");
return I - 1;
}
};
namespace impl {
template <int I>
struct static_counter_uniq_;
}
#define MAKE_SC() \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
#define MAKE_SC_WITH(start_, step_) \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}
#define NEXT_SC(c_) c_.next<__COUNTER__>()
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()
// Usage:
// constexpr auto c = MAKE_SC()
// NEXT_SC(c) // -> constexpr 0
// NEXT_SC(c) // -> constexpr 1
// NEXT_SC(c) // -> constexpr 2
} // namespace ck_tile

View File

@@ -11,6 +11,7 @@
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
@@ -20,6 +21,7 @@
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"

View File

@@ -7,6 +7,7 @@
#include <stdint.h>
#include <stdexcept>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename T>
@@ -36,6 +37,19 @@ struct DeviceMem
mpDeviceBuf = nullptr;
}
}
template <typename T>
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
{
if(mMemSize != 0)
{
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
else
{
mpDeviceBuf = nullptr;
}
ToDevice(t.data());
}
void Realloc(std::size_t mem_size)
{
if(mpDeviceBuf)
@@ -92,6 +106,27 @@ struct DeviceMem
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
}
// construct a host tensor with type T
template <typename T>
HostTensor<T> ToHost(std::size_t cpySize)
{
// TODO: host tensor could be slightly larger than the device tensor
// we just copy all data from GPU buffer
std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
HostTensor<T> h_({host_elements});
if(mpDeviceBuf)
{
HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
}
return h_;
}
template <typename T>
HostTensor<T> ToHost()
{
return ToHost<T>(mMemSize);
}
void SetZero() const
{
if(mpDeviceBuf)

View File

@@ -13,6 +13,7 @@
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
namespace ck_tile {
@@ -22,13 +23,44 @@ struct FillUniformDistribution
float a_{-5.f};
float b_{5.f};
std::optional<uint32_t> seed_{11939};
// ATTENTION: threaded does not guarantee the distribution between thread
bool threaded = false;
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
if(threaded)
{
uint32_t num_thread = std::thread::hardware_concurrency();
auto total = static_cast<std::size_t>(std::distance(first, last));
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t iw_begin = it * work_per_thread;
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
auto thread_f = [this, total, iw_begin, iw_end, &first] {
if(iw_begin > total || iw_end > total)
return;
// need to make each thread unique, add an offset to current seed
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
: std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
});
};
threads[it] = joinable_thread(thread_f);
}
}
else
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
}
}
template <typename ForwardRange>
@@ -115,13 +147,44 @@ struct FillNormalDistribution
float mean_{0.f};
float variance_{1.f};
std::optional<uint32_t> seed_{11939};
// ATTENTION: threaded does not guarantee the distribution between thread
bool threaded = false;
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
if(threaded)
{
uint32_t num_thread = std::thread::hardware_concurrency();
auto total = static_cast<std::size_t>(std::distance(first, last));
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
std::vector<joinable_thread> threads(num_thread);
for(std::size_t it = 0; it < num_thread; ++it)
{
std::size_t iw_begin = it * work_per_thread;
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
auto thread_f = [this, total, iw_begin, iw_end, &first] {
if(iw_begin > total || iw_end > total)
return;
// need to make each thread unique, add an offset to current seed
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
: std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
return ck_tile::type_convert<T>(dis(gen));
});
};
threads[it] = joinable_thread(thread_f);
}
}
else
{
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
std::generate(
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
}
}
template <typename ForwardRange>
@@ -235,6 +298,44 @@ struct FillMonotonicSeq
}
};
template <typename T, bool IsAscending = true>
struct FillStepRange
{
float start_value_{0};
float end_value_{3};
float step_{1};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::generate(first, last, [=, n = start_value_]() mutable {
auto tmp = n;
n += step_;
if constexpr(IsAscending)
{
if(n > end_value_)
n = start_value_;
}
else
{
if(n < end_value_)
n = start_value_;
}
return type_convert<T>(tmp);
});
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<
decltype(std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
};
template <typename T>
struct FillConstant
{

View File

@@ -8,12 +8,13 @@
#include <iostream>
#include <iomanip>
#include <numeric>
#include <thread>
#include <utility>
#include <vector>
#include <functional>
#include <fstream>
#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/ranges.hpp"
namespace ck_tile {
@@ -213,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
return HostTensorDescriptor(new_lengths, new_strides);
}
struct joinable_thread : std::thread
{
template <typename... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
{
}
joinable_thread(joinable_thread&&) = default;
joinable_thread& operator=(joinable_thread&&) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
template <typename F, typename... Xs>
struct ParallelTensorFunctor
{
@@ -590,6 +574,107 @@ struct HostTensor
size() * FromSize / ToSize};
}
friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
{
os << t.mDesc;
os << "[";
for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
{
os << type_convert<float>(t.mData[idx]) << " #### ";
}
else
{
os << t.mData[idx];
}
}
os << "]";
return os;
}
// read data from a file, as dtype
// the file could dumped from torch as (targeting tensor is t here)
// numpy.savetxt("f.txt", t.view(-1).numpy())
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy()) # from cuda to cpu to save
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d") # save as int
// will output f.txt, each line is a value
// dtype=float or int, internally will cast to real type
void loadtxt(std::string file_name, std::string dtype = "float")
{
std::ifstream file(file_name);
if(file.is_open())
{
std::string line;
index_t cnt = 0;
while(std::getline(file, line))
{
if(cnt >= static_cast<index_t>(mData.size()))
{
throw std::runtime_error(std::string("data read from file:") + file_name +
" is too big");
}
if(dtype == "float")
{
mData[cnt] = type_convert<T>(std::stof(line));
}
else if(dtype == "int" || dtype == "int32")
{
mData[cnt] = type_convert<T>(std::stoi(line));
}
cnt++;
}
file.close();
if(cnt < static_cast<index_t>(mData.size()))
{
std::cerr << "Warning! reading from file:" << file_name
<< ", does not match the size of this tensor" << std::endl;
}
}
else
{
// Print an error message to the standard error
// stream if the file cannot be opened.
throw std::runtime_error(std::string("unable to open file:") + file_name);
}
}
// can save to a txt file and read from torch as:
// torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
void savetxt(std::string file_name, std::string dtype = "float")
{
std::ofstream file(file_name);
if(file.is_open())
{
for(auto& itm : mData)
{
if(dtype == "float")
file << type_convert<float>(itm) << std::endl;
else if(dtype == "int")
file << type_convert<int>(itm) << std::endl;
else
// TODO: we didn't implement operator<< for all custom
// data types, here fall back to float in case compile error
file << type_convert<float>(itm) << std::endl;
}
file.close();
}
else
{
// Print an error message to the standard error
// stream if the file cannot be opened.
throw std::runtime_error(std::string("unable to open file:") + file_name);
}
}
Descriptor mDesc;
Data mData;
};

View File

@@ -0,0 +1,27 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <thread>
#include <utility>
namespace ck_tile {
struct joinable_thread : std::thread
{
template <typename... Xs>
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
{
}
joinable_thread(joinable_thread&&) = default;
joinable_thread& operator=(joinable_thread&&) = default;
~joinable_thread()
{
if(this->joinable())
this->join();
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,196 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
// number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
// 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4
// -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
// c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
///
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
template <typename AccDataType, // you only need to explcitly set this one
typename Activation, // ck_tile::element_wise::Gelu
typename ADataType,
typename GDataType,
typename DDataType,
typename ODataType,
typename AScaleDataType,
typename GScaleDataType,
typename DScaleDataType,
typename YSmoothScaleDataType,
typename TopkWeightDataType,
typename IndexDataType>
void reference_fused_moe(
const ck_tile::HostTensor<ADataType>& a_host, // [tokens, hidden_size]
const ck_tile::HostTensor<GDataType>& g_host, // [experts, interme_size_0, hidden_size]
const ck_tile::HostTensor<DDataType>& d_host, // [experts, hidden_size, interme_size_1]
const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1],
const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size],
const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
ck_tile::HostTensor<ODataType>& o_host, // [tokens, hidden_size]
const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host, // [max_num_tokens_padded]
const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
const ck_tile::HostTensor<IndexDataType>&
sorted_expert_ids_host, // [(max_num_tokens_padded + block_size - 1) / block_size]
const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]
const ck_tile::HostTensor<IndexDataType>&
token_ids_host, // [tokens, topk] --> ugly!!! remove in the future
ck_tile::index_t block_m,
ck_tile::index_t tokens,
ck_tile::index_t experts,
ck_tile::index_t hidden_size,
ck_tile::index_t intermediate_size, // this size is for gate/up
ck_tile::index_t topk,
ck_tile::index_t gate_only)
{
assert(sorted_token_ids_host.get_num_of_dimension() == 1);
assert(sorted_weight_host.get_num_of_dimension() == 1);
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
assert(num_sorted_tiles_host.get_element_size() == 1);
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
ck_tile::index_t intermediate_size_0 = intermediate_size;
ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
// TODO: better remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
{
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
// assert();
auto f = [&](auto i_flatten) {
ck_tile::index_t i_tile = i_flatten / block_m;
if(i_tile >= num_sorted_tiles)
return;
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
if(i_token >= tokens)
return;
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
auto weight = sorted_weight_host.mData[i_flatten];
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
// first gemm
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
{
acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
}
acc_0(0, i_n) = acc;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
}
ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
if(gate_only)
{
if(intermediate_size_1 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
Activation{}(y(0, i_n), acc_0(0, i_n));
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
}
}
else
{
if(intermediate_size_1 * 2 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
AccDataType tmp;
Activation{}(tmp, acc_0(0, i_n));
y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
}
}
// second gemm, loop along gemm-n
ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
{
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
}
acc_1(0, i_n) = acc * weight; // multiple weight here
}
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
}
};
// make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);
// reduce
auto r = [&](auto i_token) {
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = type_convert<AccDataType>(0);
for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
{
acc += out_topk_tokens(i_token, i_topk, i_n);
}
o_host(i_token, i_n) = type_convert<ODataType>(acc);
}
};
make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());
(void)num_sorted_tiles_host;
(void)sa_host;
(void)sg_host;
(void)sd_host;
(void)sy_host;
}
} // namespace ck_tile

View File

@@ -16,7 +16,7 @@ namespace ck_tile {
*/
template <typename DataType>
CK_TILE_HOST void
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
{
const auto x_len = x.mDesc.get_lengths();
const auto y_len = y.mDesc.get_lengths();
@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
std::vector<size_t> tmp(rank, 0);
for(index_t i = 0; i < rank; i++)
{
tmp[dims[i]] = y_coord[i];
tmp[perm[i]] = y_coord[i];
}
return tmp;
}();
@@ -54,4 +54,23 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
}
template <typename DataType>
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
{
auto x_shape = x.get_lengths();
ck_tile::index_t rank = perm.size();
std::vector<ck_tile::index_t> y_shape = [&]() {
std::vector<ck_tile::index_t> tmp(rank, 0);
for(int i = 0; i < static_cast<int>(rank); i++)
{
tmp[i] = x_shape[perm[i]];
}
return tmp;
}();
HostTensor<DataType> y(y_shape);
reference_permute(x, y, perm);
return y;
}
} // namespace ck_tile

View File

@@ -572,6 +572,105 @@ struct FastGelu
}
};
struct FastGeluAsm
{
template <typename Y, typename X>
CK_TILE_HOST void operator()(Y& y, const X& x) const;
template <typename Y, typename X>
CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
template <>
CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
{
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u = x * (c1 * x * x + c2);
const float emu = exp(u);
y = x / (1.f + emu);
}
// device code, use lower precision "__ocml_exp_f32" and "rcp"
template <>
CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
{
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
float tmp;
asm volatile("v_mul_f32 %[v_tmp], %[v_x], %[v_x] ; x*x\n"
"v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_mul_f32 %[v_tmp], %[v_tmp], %[v_x] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_exp_f32 %[v_tmp], %[v_tmp] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"s_nop 0 ; hazard for exp\n"
"v_add_f32 %[v_tmp], %[v_tmp], 1.0 ; emu+1.0f\n"
"v_rcp_f32 %[v_tmp], %[v_tmp] ; 1/(emu+1.0f)\n"
"s_nop 0 ; hazard for rcp \n"
"v_mul_f32 %[v_y], %[v_tmp], %[v_x] ; x * 1/(emu+1f)\n"
: [v_y] "=v"(y), [v_tmp] "+v"(tmp)
: [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
:);
}
template <>
CK_TILE_HOST void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
{
const float c1 = -2.0 * 0.035677f;
const float c2 = -2.0 * 0.797885f;
const float u0 = x.x * (c1 * x.x * x.x + c2);
const float emu0 = exp(u0);
y.x = x.x / (1.f + emu0);
const float u1 = x.y * (c1 * x.y * x.y + c2);
const float emu1 = exp(u1);
y.y = x.y / (1.f + emu1);
}
// this is packed verion to remove data hazard for trans
template <>
CK_TILE_DEVICE void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
{
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
float c2 = -2.0 * 0.797885f;
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
float tmp0, tmp1;
float y0 = x.x, y1 = x.y;
asm volatile(
"v_mul_f32 %[v_tmp0], %[v_y0], %[v_y0] ; x*x\n"
"v_mul_f32 %[v_tmp1], %[v_y1], %[v_y1] ; x*x\n"
"v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_y0] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_y1] ; x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
"v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
"v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f\n"
"v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f\n"
"v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)\n"
"v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)\n"
"v_mul_f32 %[v_y0], %[v_tmp0], %[v_y0] ; x * 1/(emu+1f)\n"
"v_mul_f32 %[v_y1], %[v_tmp1], %[v_y1] ; x * 1/(emu+1f)\n"
: [v_y0] "+v"(y0),
[v_y1] "+v"(y1),
[v_c2] "+v"(c2),
// NOTE! it is totally possible that c2/y0/y1 share same register, they are all local
// tmp variables we need to expicitly hint compiler they may read+write, to allow
// allocate different register , the side effect is c2=** may issue for every such
// inline asm block
[v_tmp0] "+v"(tmp0),
[v_tmp1] "+v"(tmp1)
: [s_c1] "s"(c1), [s_log2e] "s"(log2e_)
:);
y.x = y0;
y.y = y1;
}
};
// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu

View File

@@ -0,0 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -0,0 +1,615 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace ck_tile {
// A async load to LDS, B direct to AGPR
// B matrix preshuffled in br*kr*w
// require 4 wave, occupancy=1c
// agpr useage:256
// vgpr usage:64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112)
//
// for this gemm, 4 16x16x16 transposed layout
// input A vpgpr layout
// v0-v15: [ 0:15](gemm_m)x128(gemm_k)
// v16-v31: [16:31](gemm_m)x128(gemm_k)
// input B vpgpr layout
// v0-v15: [ 0: 15](gemm_n)x128(gemm_k)
// v16-v31: [ 64: 79](gemm_n)x128(gemm_k)
// ......................
// v111-v127: [448:463](gemm_n)x128(gemm_k)
// output C vpgpr layout
// v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n)
// v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n)
// v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n)
// v12-v15: [16:31](gemm_m)x[64: 79](gemm_n)
// ......................
// v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n)
// v60-v63: [16:31](gemm_m)x[448:463](gemm_n)
struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
{
static constexpr index_t Block_M = 32;
static constexpr index_t Block_N = 512;
static constexpr index_t Block_K = 128;
static constexpr index_t WarpPerBlock_M = 1;
static constexpr index_t WarpPerBlock_N = 4;
static constexpr index_t WarpPerBlock_K = 1;
static constexpr index_t NumWarps = 4;
static constexpr index_t Warp_M = 16;
static constexpr index_t Warp_N = 16;
static constexpr index_t Warp_K = 32; // 16 * SubKPacks
static constexpr index_t BlockSize = 256;
static constexpr index_t SubKPacks = 2; // this is used to gurantee every threads can do dwordx4
// TODO: note Nr/Kr/W need consider SubKPacks
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
{
using CDataType = float;
constexpr auto c_block_dstr = MakeCBlockDist();
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack_; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = Block_M / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NumIssues>{}),
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
// template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
// load from LDS to register, every wave has same layout
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KPad = KPack_; // pad between warps
constexpr index_t kAMLane = 16;
constexpr index_t kABKLane = 4;
constexpr index_t kABKPerLane = 4;
constexpr index_t kKIter = 2;
static_assert(KPack_ == (kABKPerLane * kKIter));
constexpr auto lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<Repeat_M>{}, // m0 y
number<kAMLane>{}, // m1 p
number<Repeat_K>{}, // k0 y
number<kABKLane>{}, // k1 p
number<KPack_>{}), // k2 y-vector
make_tuple(number<kAMLane*(Block_K + KPad)>{}, // m0
number<Block_K + KPad>{}, // m1
number<kABKLane * KPack_>{}, // k0
number<KPack_>{}, // k1
number<1>{}), // k2
number<KPack_>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
make_merge_transform(
make_tuple(number<Repeat_K>{}, number<kABKLane>{}, number<KPack_>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
static constexpr auto GetGemm_AWarpEnc()
{
constexpr index_t kAMLane = 16;
constexpr index_t kABKLane = 4;
constexpr index_t kABKPerLane = 4;
constexpr index_t kKIter = 2;
using enc_ = tile_distribution_encoding<
sequence<>,
tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane * kKIter>>,
tuple<sequence<2, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<1>>;
return enc_{};
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return 32 * (128 + 8) * sizeof(bf16_t);
}
};
struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
using ADataType = bf16_t;
using BDataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto
operator()(const ARes& res_a,
const ACoords& cached_coords_a,
const BRes& res_b,
const BCoords& cached_coords_b,
CK_TILE_LDS_ADDR void* smem,
index_t k,
index_t tile_offset_a, // for each tile, the offset to move for each unroll
index_t tile_offset_b) // for each tile, the offset to move for each unroll
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
auto a_sst = make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
MakeLdsStoreDesc_A().get_lengths(),
{0, 0, 0});
auto a_sld = [&]() {
constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<WarpPerBlock_N>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
MakeLdsLoadDesc_A().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
constexpr auto smem_buf_size =
MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
static_assert(a_sld.get_num_of_access() == 8);
constexpr auto sld_os = generate_tuple(
[&](auto i_access) {
return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
},
number<a_sld.get_num_of_access()>{});
index_t loop_cnt = k / Block_K;
// this is the acc thread buffer
fp32x4_t v_acc[16]{.0f};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
: [s_loop_cnt]"+s"(loop_cnt),
[v_acc_0]"+v"(v_acc[0]),
[v_acc_1]"+v"(v_acc[1]),
[v_acc_2]"+v"(v_acc[2]),
[v_acc_3]"+v"(v_acc[3]),
[v_acc_4]"+v"(v_acc[4]),
[v_acc_5]"+v"(v_acc[5]),
[v_acc_6]"+v"(v_acc[6]),
[v_acc_7]"+v"(v_acc[7]),
[v_acc_8]"+v"(v_acc[8]),
[v_acc_9]"+v"(v_acc[9]),
[v_acc_10]"+v"(v_acc[10]),
[v_acc_11]"+v"(v_acc[11]),
[v_acc_12]"+v"(v_acc[12]),
[v_acc_13]"+v"(v_acc[13]),
[v_acc_14]"+v"(v_acc[14]),
[v_acc_15]"+v"(v_acc[15]),
[s_mem_]"+r"(smem)
: [s_res_a0]"s"(res_a[0]),
[s_res_a1]"s"(res_a[1]),
[s_res_a2]"s"(res_a[2]),
[s_res_a3]"s"(res_a[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
[s_m0_init]"s"(m0_init_value),
[s_size_per_issue]"s"(size_per_issue),
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
[sld_os_0]"n"(sld_os[number<0>{}].value),
[sld_os_1]"n"(sld_os[number<1>{}].value),
[sld_os_2]"n"(sld_os[number<2>{}].value),
[sld_os_3]"n"(sld_os[number<3>{}].value),
[sld_os_4]"n"(sld_os[number<4>{}].value),
[sld_os_5]"n"(sld_os[number<5>{}].value),
[sld_os_6]"n"(sld_os[number<6>{}].value),
[sld_os_7]"n"(sld_os[number<7>{}].value),
[s_tile_os_a]"s"(tile_offset_a_bytes),
[s_tile_os_b]"s"(tile_offset_b_bytes)
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
"s86", // s86 as tmp
"v64", "v65", "v66", "v67", "v68", "v69",
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
"v124", "v125", "v126", "v127"
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto c = MakeCBlockTile();
for(auto i = 0; i < 16; i++)
{
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
}
return c;
}
};
struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
using ADataType = fp16_t;
using BDataType = fp16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto
operator()(const ARes& res_a,
const ACoords& cached_coords_a,
const BRes& res_b,
const BCoords& cached_coords_b,
CK_TILE_LDS_ADDR void* smem,
index_t k,
index_t tile_offset_a, // for each tile, the offset to move for each unroll
index_t tile_offset_b) // for each tile, the offset to move for each unroll
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
auto a_sst = make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
MakeLdsStoreDesc_A().get_lengths(),
{0, 0, 0});
auto a_sld = [&]() {
constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<WarpPerBlock_N>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode =
detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
MakeLdsLoadDesc_A().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
constexpr auto smem_buf_size =
MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
static_assert(a_sld.get_num_of_access() == 8);
constexpr auto sld_os = generate_tuple(
[&](auto i_access) {
return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
},
number<a_sld.get_num_of_access()>{});
index_t loop_cnt = k / Block_K;
// this is the acc thread buffer
fp32x4_t v_acc[16]{.0f};
// B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
// clang-format off
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
: [s_loop_cnt]"+s"(loop_cnt),
[v_acc_0]"+v"(v_acc[0]),
[v_acc_1]"+v"(v_acc[1]),
[v_acc_2]"+v"(v_acc[2]),
[v_acc_3]"+v"(v_acc[3]),
[v_acc_4]"+v"(v_acc[4]),
[v_acc_5]"+v"(v_acc[5]),
[v_acc_6]"+v"(v_acc[6]),
[v_acc_7]"+v"(v_acc[7]),
[v_acc_8]"+v"(v_acc[8]),
[v_acc_9]"+v"(v_acc[9]),
[v_acc_10]"+v"(v_acc[10]),
[v_acc_11]"+v"(v_acc[11]),
[v_acc_12]"+v"(v_acc[12]),
[v_acc_13]"+v"(v_acc[13]),
[v_acc_14]"+v"(v_acc[14]),
[v_acc_15]"+v"(v_acc[15]),
[s_mem_]"+r"(smem)
: [s_res_a0]"s"(res_a[0]),
[s_res_a1]"s"(res_a[1]),
[s_res_a2]"s"(res_a[2]),
[s_res_a3]"s"(res_a[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
[v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
[v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
[v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
[v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
[v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
[v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
[v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
[s_m0_init]"s"(m0_init_value),
[s_size_per_issue]"s"(size_per_issue),
[smem_sz]"n"(smem_buf_size), //(smem_buf_size),
[sld_os_0]"n"(sld_os[number<0>{}].value),
[sld_os_1]"n"(sld_os[number<1>{}].value),
[sld_os_2]"n"(sld_os[number<2>{}].value),
[sld_os_3]"n"(sld_os[number<3>{}].value),
[sld_os_4]"n"(sld_os[number<4>{}].value),
[sld_os_5]"n"(sld_os[number<5>{}].value),
[sld_os_6]"n"(sld_os[number<6>{}].value),
[sld_os_7]"n"(sld_os[number<7>{}].value),
[s_tile_os_a]"s"(tile_offset_a_bytes),
[s_tile_os_b]"s"(tile_offset_b_bytes)
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
"s86", // s86 as tmp
"v64", "v65", "v66", "v67", "v68", "v69",
"v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
"v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
"v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
"v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
"v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
"v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
"v124", "v125", "v126", "v127"
);
// clang-format on
#pragma clang diagnostic pop
// return local scratch
auto c = MakeCBlockTile();
for(auto i = 0; i < 16; i++)
{
c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
}
return c;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,562 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace ck_tile {
// "S"tream update output along "N"
// A in smem, B load from global
// require 4 wave, occupancy=1c
struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
static constexpr index_t Block_M = 32;
static constexpr index_t Block_N = 128;
static constexpr index_t Block_K = 512;
static constexpr index_t WarpPerBlock_M = 1;
static constexpr index_t WarpPerBlock_N = 4;
static constexpr index_t WarpPerBlock_K = 1;
static constexpr index_t Warp_M = 16;
static constexpr index_t Warp_N = 16;
static constexpr index_t Warp_K = 32;
static constexpr index_t BlockSize = 256;
// static constexpr index_t KPack = 2; // this is used to gurantee every threads can do dwordx4
// TODO: note Nr/Kr/W need consider KPack
static constexpr index_t Block_W = Warp_N * Warp_K; // 512 element
static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
static constexpr index_t Block_Kr = Block_K / Warp_K; // 4
static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 2
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 16
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t);
}
};
struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16 : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using BDataType = bf16_t;
using ODataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename BRes,
typename BCoords,
typename ORes,
typename OCoords,
typename OFlags,
typename ScaleTensor>
CK_TILE_DEVICE auto
operator()(const BRes& res_b,
const BCoords& cached_coords_b,
const ORes& res_o,
const OCoords& cached_coords_o,
const OFlags& o_flags, // this should be in sgpr
CK_TILE_LDS_ADDR void* smem,
index_t n, // loop along n dim
const ScaleTensor& scale_,
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_o)
{
static_assert(BCoords::size() == 8); // 8
static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
float s1 = scale_[number<1>{}];
index_t loop_cnt = n / Block_N;
register float v_c0 asm("v64");
register float v_c1 asm("v65");
register float v_c2 asm("v66");
register float v_c3 asm("v67");
register float v_c4 asm("v68");
register float v_c5 asm("v69");
register float v_c6 asm("v70");
register float v_c7 asm("v71");
register float v_c8 asm("v72");
register float v_c9 asm("v73");
register float v_c10 asm("v74");
register float v_c11 asm("v75");
register float v_c12 asm("v76");
register float v_c13 asm("v77");
register float v_c14 asm("v78");
register float v_c15 asm("v79");
register float v_c16 asm("v80");
register float v_c17 asm("v81");
register float v_c18 asm("v82");
register float v_c19 asm("v83");
register float v_c20 asm("v84");
register float v_c21 asm("v85");
register float v_c22 asm("v86");
register float v_c23 asm("v87");
register float v_c24 asm("v88");
register float v_c25 asm("v89");
register float v_c26 asm("v90");
register float v_c27 asm("v91");
register float v_c28 asm("v92");
register float v_c29 asm("v93");
register float v_c30 asm("v94");
register float v_c31 asm("v95");
int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int lane_id = threadIdx.x % 64;
int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
sld_y_os *= 2;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
sfl_sst *= 2;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
sfl_sld *= 2;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt),
[c0]"+v" (v_c0),
[c1]"+v" (v_c1),
[c2]"+v" (v_c2),
[c3]"+v" (v_c3),
[c4]"+v" (v_c4),
[c5]"+v" (v_c5),
[c6]"+v" (v_c6),
[c7]"+v" (v_c7),
[c8]"+v" (v_c8),
[c9]"+v" (v_c9),
[c10]"+v"(v_c10),
[c11]"+v"(v_c11),
[c12]"+v"(v_c12),
[c13]"+v"(v_c13),
[c14]"+v"(v_c14),
[c15]"+v"(v_c15),
[c16]"+v"(v_c16),
[c17]"+v"(v_c17),
[c18]"+v"(v_c18),
[c19]"+v"(v_c19),
[c20]"+v"(v_c20),
[c21]"+v"(v_c21),
[c22]"+v"(v_c22),
[c23]"+v"(v_c23),
[c24]"+v"(v_c24),
[c25]"+v"(v_c25),
[c26]"+v"(v_c26),
[c27]"+v"(v_c27),
[c28]"+v"(v_c28),
[c29]"+v"(v_c29),
[c30]"+v"(v_c30),
[c31]"+v"(v_c31)
:
[sld_a_base]"n"(0),
[shfl_base]"n"(0),
[v_sld_y_os]"v"(sld_y_os),
[v_sfl_sld]"v"(sfl_sld),
[v_sfl_sst]"v"(sfl_sst),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
"s36", "s37",
"v50", "v54", "v55",
"v64","v65","v66","v67","v68","v69","v70","v71",
"v72","v73","v74","v75","v76","v77","v78","v79",
"v80","v81","v82","v83","v84","v85","v86","v87",
"v88","v89","v90","v91","v92","v93","v94","v95",
"v128", "v129", "v130", "v131",
"v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
"v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
"v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
"v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
"v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
"v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
"v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
"v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
"v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
"v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
"v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
"v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
"v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
"v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
"v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
"v252", "v253", "v254", "v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16 : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
using BDataType = bf16_t;
using ODataType = bf16_t;
// TODO: need paired with tile_window_linear!
// TODO: need call init_raw() before call this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
template <typename BRes,
typename BCoords,
typename ORes,
typename OCoords,
typename OFlags,
typename ScaleTensor>
CK_TILE_DEVICE auto
operator()(const BRes& res_b,
const BCoords& cached_coords_b,
const ORes& res_o,
const OCoords& cached_coords_o,
const OFlags& o_flags, // this should be in sgpr
CK_TILE_LDS_ADDR void* smem,
index_t n, // loop along n dim
const ScaleTensor& scale_,
index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t tile_offset_o)
{
static_assert(BCoords::size() == 8); // 8
static_assert(OCoords::size() == 8);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
float s1 = scale_[number<1>{}];
index_t loop_cnt = n / Block_N;
register float v_c0 asm("v64");
register float v_c1 asm("v65");
register float v_c2 asm("v66");
register float v_c3 asm("v67");
register float v_c4 asm("v68");
register float v_c5 asm("v69");
register float v_c6 asm("v70");
register float v_c7 asm("v71");
register float v_c8 asm("v72");
register float v_c9 asm("v73");
register float v_c10 asm("v74");
register float v_c11 asm("v75");
register float v_c12 asm("v76");
register float v_c13 asm("v77");
register float v_c14 asm("v78");
register float v_c15 asm("v79");
register float v_c16 asm("v80");
register float v_c17 asm("v81");
register float v_c18 asm("v82");
register float v_c19 asm("v83");
register float v_c20 asm("v84");
register float v_c21 asm("v85");
register float v_c22 asm("v86");
register float v_c23 asm("v87");
register float v_c24 asm("v88");
register float v_c25 asm("v89");
register float v_c26 asm("v90");
register float v_c27 asm("v91");
register float v_c28 asm("v92");
register float v_c29 asm("v93");
register float v_c30 asm("v94");
register float v_c31 asm("v95");
int32_t nan_hi = 0x7fff0000;
int32_t nan_lo = 0x00007fff;
// in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
// every threads need 8xK in contiguous register
// ... and every wave need the same data
int lane_id = threadIdx.x % 64;
int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
sld_y_os *= 2;
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
// sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
sfl_sst *= 2;
// from LDS we need load as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16) * Nv(4) + 4)
// ( 2 issue) (rem 32-lane) (4 wave*4issue) 2lane*1ussue(pk2)
// sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
sfl_sld *= 2;
// B nr->kr
// clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
:[smem_]"+r"(smem),
[s_loop_cnt]"+s"(loop_cnt),
[c0]"+v" (v_c0),
[c1]"+v" (v_c1),
[c2]"+v" (v_c2),
[c3]"+v" (v_c3),
[c4]"+v" (v_c4),
[c5]"+v" (v_c5),
[c6]"+v" (v_c6),
[c7]"+v" (v_c7),
[c8]"+v" (v_c8),
[c9]"+v" (v_c9),
[c10]"+v"(v_c10),
[c11]"+v"(v_c11),
[c12]"+v"(v_c12),
[c13]"+v"(v_c13),
[c14]"+v"(v_c14),
[c15]"+v"(v_c15),
[c16]"+v"(v_c16),
[c17]"+v"(v_c17),
[c18]"+v"(v_c18),
[c19]"+v"(v_c19),
[c20]"+v"(v_c20),
[c21]"+v"(v_c21),
[c22]"+v"(v_c22),
[c23]"+v"(v_c23),
[c24]"+v"(v_c24),
[c25]"+v"(v_c25),
[c26]"+v"(v_c26),
[c27]"+v"(v_c27),
[c28]"+v"(v_c28),
[c29]"+v"(v_c29),
[c30]"+v"(v_c30),
[c31]"+v"(v_c31)
:
[sld_a_base]"n"(0),
[shfl_base]"n"(0),
[v_sld_y_os]"v"(sld_y_os),
[v_sfl_sld]"v"(sfl_sld),
[v_sfl_sst]"v"(sfl_sst),
[s_res_o0]"s"(res_o[0]),
[s_res_o1]"s"(res_o[1]),
//[s_res_o2]"s"(res_o[2]),
//[s_res_o3]"s"(res_o[3]),
[s_res_b0]"s"(res_b[0]),
[s_res_b1]"s"(res_b[1]),
[s_res_b2]"s"(res_b[2]),
[s_res_b3]"s"(res_b[3]),
[v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
[v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
[v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
[v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
[v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
[v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
[v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
[v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
[v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
[v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
[v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
[v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
[v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
[v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
"a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
"a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
"a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
"a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
"a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
"a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
"a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
"a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
"a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
"a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
"a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
"a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
"a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
"a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
"a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
"a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
"a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
"a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
"a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
"a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
"a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
"a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
"a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
"s36", "s37",
"v50", "v54", "v55",
"v64","v65","v66","v67","v68","v69","v70","v71",
"v72","v73","v74","v75","v76","v77","v78","v79",
"v80","v81","v82","v83","v84","v85","v86","v87",
"v88","v89","v90","v91","v92","v93","v94","v95",
"v128", "v129", "v130", "v131",
"v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
"v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
"v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
"v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
"v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
"v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
"v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
"v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
"v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
"v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
"v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
"v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
"v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
"v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
"v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
"v252", "v253", "v254", "v255"
);
#pragma clang diagnostic pop
// clang-format on
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#define CK_TILE_FLATMM_UK_MFMA_FP16 0
#define CK_TILE_FLATMM_UK_MFMA_BF16 1
#define CK_TILE_FLATMM_UK_MFMA_INT8 2
#define CK_TILE_FLATMM_UK_MFMA_FP8 3
#define CK_TILE_FLATMM_UK_MFMA_BF8 4

View File

@@ -0,0 +1 @@
the files under this folder should not be included directly!

View File

@@ -0,0 +1,613 @@
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
# define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
# define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cmp_u_f32 s[36:37], " x0_ ", " x0_ " \n" \
" v_add3_u32 v50, " x0_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v54, v50, %[v_nan_hi], s[36:37] \n" \
" v_cmp_u_f32 s[36:37], " x1_ ", " x1_ " \n" \
" v_add3_u32 v50, " x1_ ", %[v_nan_lo], 1 \n" \
" v_cndmask_b32 v55, v50, %[v_nan_hi], s[36:37] \n" \
" v_perm_b32 " y_ ", v55, v54, s52 \n"
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
# define _UK_PK_CVT_(x0_, x1_, y_) \
" v_cvt_f16_f32 v54, " x0_ " \n" \
" v_cvt_f16_f32 v55, " x1_ " \n" \
" v_pack_b32_f16 " y_ ", v54, v55 \n"
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"
#endif
";-------------------------------------------------------------\n"
" s_mov_b32 s52, 0x07060302 ; v_perm\n"
" s_mov_b64 s[38:39], exec ; save current exec\n"
" s_mov_b32 s8, %[s_res_o0] \n"
" s_mov_b32 s9, %[s_res_o1] \n"
" s_mov_b32 s12, %[s_res_b0] \n"
" s_mov_b32 s13, %[s_res_b1] \n"
" s_mov_b32 s14, %[s_res_b2] \n"
" s_mov_b32 s15, %[s_res_b3] \n"
" ds_read_b64 v[128:129], %[v_sld_y_os] offset:0 + %[sld_a_base] \n"
" ds_read_b64 v[130:131], %[v_sld_y_os] offset:128 + %[sld_a_base] \n"
" ds_read_b64 v[132:133], %[v_sld_y_os] offset:1024 + %[sld_a_base] \n"
" ds_read_b64 v[134:135], %[v_sld_y_os] offset:1152 + %[sld_a_base] \n"
" ds_read_b64 v[136:137], %[v_sld_y_os] offset:2048 + %[sld_a_base] \n"
" ds_read_b64 v[138:139], %[v_sld_y_os] offset:2176 + %[sld_a_base] \n"
" ds_read_b64 v[140:141], %[v_sld_y_os] offset:3072 + %[sld_a_base] \n"
" ds_read_b64 v[142:143], %[v_sld_y_os] offset:3200 + %[sld_a_base] \n"
" ds_read_b64 v[144:145], %[v_sld_y_os] offset:4096 + %[sld_a_base] \n"
" ds_read_b64 v[146:147], %[v_sld_y_os] offset:4224 + %[sld_a_base] \n"
" ds_read_b64 v[148:149], %[v_sld_y_os] offset:5120 + %[sld_a_base] \n"
" ds_read_b64 v[150:151], %[v_sld_y_os] offset:5248 + %[sld_a_base] \n"
" ds_read_b64 v[152:153], %[v_sld_y_os] offset:6144 + %[sld_a_base] \n"
" ds_read_b64 v[154:155], %[v_sld_y_os] offset:6272 + %[sld_a_base] \n"
" ds_read_b64 v[156:157], %[v_sld_y_os] offset:7168 + %[sld_a_base] \n"
" ds_read_b64 v[158:159], %[v_sld_y_os] offset:7296 + %[sld_a_base] \n"
" ds_read_b64 v[160:161], %[v_sld_y_os] offset:8192 + %[sld_a_base] \n"
" ds_read_b64 v[162:163], %[v_sld_y_os] offset:8320 + %[sld_a_base] \n"
" ds_read_b64 v[164:165], %[v_sld_y_os] offset:9216 + %[sld_a_base] \n"
" ds_read_b64 v[166:167], %[v_sld_y_os] offset:9344 + %[sld_a_base] \n"
" ds_read_b64 v[168:169], %[v_sld_y_os] offset:10240 + %[sld_a_base] \n"
" ds_read_b64 v[170:171], %[v_sld_y_os] offset:10368 + %[sld_a_base] \n"
" ds_read_b64 v[172:173], %[v_sld_y_os] offset:11264 + %[sld_a_base] \n"
" ds_read_b64 v[174:175], %[v_sld_y_os] offset:11392 + %[sld_a_base] \n"
" ds_read_b64 v[176:177], %[v_sld_y_os] offset:12288 + %[sld_a_base] \n"
" ds_read_b64 v[178:179], %[v_sld_y_os] offset:12416 + %[sld_a_base] \n"
" ds_read_b64 v[180:181], %[v_sld_y_os] offset:13312 + %[sld_a_base] \n"
" ds_read_b64 v[182:183], %[v_sld_y_os] offset:13440 + %[sld_a_base] \n"
" ds_read_b64 v[184:185], %[v_sld_y_os] offset:14336 + %[sld_a_base] \n"
" ds_read_b64 v[186:187], %[v_sld_y_os] offset:14464 + %[sld_a_base] \n"
" ds_read_b64 v[188:189], %[v_sld_y_os] offset:15360 + %[sld_a_base] \n"
" ds_read_b64 v[190:191], %[v_sld_y_os] offset:15488 + %[sld_a_base] \n"
" ds_read_b64 v[192:193], %[v_sld_y_os] offset:16384 + %[sld_a_base] \n"
" ds_read_b64 v[194:195], %[v_sld_y_os] offset:16512 + %[sld_a_base] \n"
" ds_read_b64 v[196:197], %[v_sld_y_os] offset:17408 + %[sld_a_base] \n"
" ds_read_b64 v[198:199], %[v_sld_y_os] offset:17536 + %[sld_a_base] \n"
" ds_read_b64 v[200:201], %[v_sld_y_os] offset:18432 + %[sld_a_base] \n"
" ds_read_b64 v[202:203], %[v_sld_y_os] offset:18560 + %[sld_a_base] \n"
" ds_read_b64 v[204:205], %[v_sld_y_os] offset:19456 + %[sld_a_base] \n"
" ds_read_b64 v[206:207], %[v_sld_y_os] offset:19584 + %[sld_a_base] \n"
" ds_read_b64 v[208:209], %[v_sld_y_os] offset:20480 + %[sld_a_base] \n"
" ds_read_b64 v[210:211], %[v_sld_y_os] offset:20608 + %[sld_a_base] \n"
" ds_read_b64 v[212:213], %[v_sld_y_os] offset:21504 + %[sld_a_base] \n"
" ds_read_b64 v[214:215], %[v_sld_y_os] offset:21632 + %[sld_a_base] \n"
" ds_read_b64 v[216:217], %[v_sld_y_os] offset:22528 + %[sld_a_base] \n"
" ds_read_b64 v[218:219], %[v_sld_y_os] offset:22656 + %[sld_a_base] \n"
" ds_read_b64 v[220:221], %[v_sld_y_os] offset:23552 + %[sld_a_base] \n"
" ds_read_b64 v[222:223], %[v_sld_y_os] offset:23680 + %[sld_a_base] \n"
" ds_read_b64 v[224:225], %[v_sld_y_os] offset:24576 + %[sld_a_base] \n"
" ds_read_b64 v[226:227], %[v_sld_y_os] offset:24704 + %[sld_a_base] \n"
" ds_read_b64 v[228:229], %[v_sld_y_os] offset:25600 + %[sld_a_base] \n"
" ds_read_b64 v[230:231], %[v_sld_y_os] offset:25728 + %[sld_a_base] \n"
" ds_read_b64 v[232:233], %[v_sld_y_os] offset:26624 + %[sld_a_base] \n"
" ds_read_b64 v[234:235], %[v_sld_y_os] offset:26752 + %[sld_a_base] \n"
" ds_read_b64 v[236:237], %[v_sld_y_os] offset:27648 + %[sld_a_base] \n"
" ds_read_b64 v[238:239], %[v_sld_y_os] offset:27776 + %[sld_a_base] \n"
" ds_read_b64 v[240:241], %[v_sld_y_os] offset:28672 + %[sld_a_base] \n"
" ds_read_b64 v[242:243], %[v_sld_y_os] offset:28800 + %[sld_a_base] \n"
" ds_read_b64 v[244:245], %[v_sld_y_os] offset:29696 + %[sld_a_base] \n"
" ds_read_b64 v[246:247], %[v_sld_y_os] offset:29824 + %[sld_a_base] \n"
" ds_read_b64 v[248:249], %[v_sld_y_os] offset:30720 + %[sld_a_base] \n"
" ds_read_b64 v[250:251], %[v_sld_y_os] offset:30848 + %[sld_a_base] \n"
" ds_read_b64 v[252:253], %[v_sld_y_os] offset:31744 + %[sld_a_base] \n"
" ds_read_b64 v[254:255], %[v_sld_y_os] offset:31872 + %[sld_a_base] \n"
" s_waitcnt 0 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_waitcnt 0 \n"
"L_start%=: \n"
" s_waitcnt vmcnt(32) \n"
" s_barrier \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[2:3], v[130:131], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[4:5], v[132:133], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[6:7], v[134:135], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[8:9], v[136:137], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[10:11], v[138:139], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[12:13], v[140:141], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[14:15], v[142:143], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[0:1], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[2:3], v[194:195], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[4:5], v[196:197], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[6:7], v[198:199], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[8:9], v[200:201], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[10:11], v[202:203], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[12:13], v[204:205], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[14:15], v[206:207], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[16:17], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[18:19], v[130:131], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[20:21], v[132:133], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[22:23], v[134:135], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[24:25], v[136:137], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[26:27], v[138:139], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[28:29], v[140:141], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[30:31], v[142:143], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[16:17], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[18:19], v[194:195], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[20:21], v[196:197], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[22:23], v[198:199], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[24:25], v[200:201], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[26:27], v[202:203], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[28:29], v[204:205], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[30:31], v[206:207], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[32:33], v[144:145], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[34:35], v[146:147], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[36:37], v[148:149], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[38:39], v[150:151], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[40:41], v[152:153], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[42:43], v[154:155], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[44:45], v[156:157], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[46:47], v[158:159], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[32:33], v[208:209], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[34:35], v[210:211], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[36:37], v[212:213], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[38:39], v[214:215], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[40:41], v[216:217], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[42:43], v[218:219], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[44:45], v[220:221], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[46:47], v[222:223], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[48:49], v[144:145], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[50:51], v[146:147], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[52:53], v[148:149], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[54:55], v[150:151], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[56:57], v[152:153], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[58:59], v[154:155], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[60:61], v[156:157], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[62:63], v[158:159], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[48:49], v[208:209], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[50:51], v[210:211], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[52:53], v[212:213], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[54:55], v[214:215], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[56:57], v[216:217], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[58:59], v[218:219], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[60:61], v[220:221], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[62:63], v[222:223], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[64:65], v[160:161], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[66:67], v[162:163], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[68:69], v[164:165], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[70:71], v[166:167], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[72:73], v[168:169], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[74:75], v[170:171], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[76:77], v[172:173], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[78:79], v[174:175], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[64:65], v[224:225], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[66:67], v[226:227], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[68:69], v[228:229], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[70:71], v[230:231], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[72:73], v[232:233], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[74:75], v[234:235], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[76:77], v[236:237], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[78:79], v[238:239], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[80:81], v[160:161], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[82:83], v[162:163], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[84:85], v[164:165], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[86:87], v[166:167], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[88:89], v[168:169], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[90:91], v[170:171], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[92:93], v[172:173], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[94:95], v[174:175], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[80:81], v[224:225], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[82:83], v[226:227], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[84:85], v[228:229], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[86:87], v[230:231], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[88:89], v[232:233], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[90:91], v[234:235], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[92:93], v[236:237], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[94:95], v[238:239], [%[c12], %[c13], %[c14], %[c15]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[96:97], v[176:177], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[98:99], v[178:179], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[100:101], v[180:181], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[102:103], v[182:183], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[104:105], v[184:185], [%[c0], %[c1], %[c2], %[c3]] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[106:107], v[186:187], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[108:109], v[188:189], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[110:111], v[190:191], [%[c0], %[c1], %[c2], %[c3]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[96:97], v[240:241], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[98:99], v[242:243], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[100:101], v[244:245], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[102:103], v[246:247], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[104:105], v[248:249], [%[c4], %[c5], %[c6], %[c7]] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[106:107], v[250:251], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[108:109], v[252:253], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[110:111], v[254:255], [%[c4], %[c5], %[c6], %[c7]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[112:113], v[176:177], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[114:115], v[178:179], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[116:117], v[180:181], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[118:119], v[182:183], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[120:121], v[184:185], [%[c8], %[c9], %[c10], %[c11]] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[122:123], v[186:187], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[124:125], v[188:189], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[126:127], v[190:191], [%[c8], %[c9], %[c10], %[c11]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[112:113], v[240:241], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[114:115], v[242:243], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[116:117], v[244:245], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[118:119], v[246:247], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[120:121], v[248:249], [%[c12], %[c13], %[c14], %[c15]] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[122:123], v[250:251], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[124:125], v[252:253], [%[c12], %[c13], %[c14], %[c15]] \n"
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[126:127], v[254:255], [%[c12], %[c13], %[c14], %[c15]]\n"
" v_mul_f32 %[c0], %[scale_0], %[c0] \n"
" v_mul_f32 %[c1], %[scale_0], %[c1] \n"
" v_mul_f32 %[c2], %[scale_0], %[c2] \n"
" v_mul_f32 %[c3], %[scale_0], %[c3] \n"
" v_mul_f32 %[c4], %[scale_1], %[c4] \n"
" v_mul_f32 %[c5], %[scale_1], %[c5] \n"
" v_mul_f32 %[c6], %[scale_1], %[c6] \n"
" v_mul_f32 %[c7], %[scale_1], %[c7] \n"
" v_mul_f32 %[c8], %[scale_0], %[c8] \n"
" v_mul_f32 %[c9], %[scale_0], %[c9] \n"
" v_mul_f32 %[c10], %[scale_0], %[c10] \n"
" v_mul_f32 %[c11], %[scale_0], %[c11] \n"
" v_mul_f32 %[c12], %[scale_1], %[c12] \n"
" v_mul_f32 %[c13], %[scale_1], %[c13] \n"
" v_mul_f32 %[c14], %[scale_1], %[c14] \n"
" v_mul_f32 %[c15], %[scale_1], %[c15] \n"
_UK_PK_CVT_("%[c0]", "%[c1]", "%[c0]")
_UK_PK_CVT_("%[c2]", "%[c3]", "%[c1]")
_UK_PK_CVT_("%[c4]", "%[c5]", "%[c2]")
_UK_PK_CVT_("%[c6]", "%[c7]", "%[c3]")
_UK_PK_CVT_("%[c8]", "%[c9]", "%[c4]")
_UK_PK_CVT_("%[c10]", "%[c11]", "%[c5]")
_UK_PK_CVT_("%[c12]", "%[c13]", "%[c6]")
_UK_PK_CVT_("%[c14]", "%[c15]", "%[c7]")
" ;------------------------------ \n"
" ds_write_b64 %[v_sfl_sst], [%[c0],%[c1]] offset:0 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c2],%[c3]] offset:4352 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c4],%[c5]] offset:2176 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:6528 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 %[c0], %[v_sfl_sld] offset:0 + %[shfl_base] \n"
" ds_read_b32 %[c1], %[v_sfl_sld] offset:32 + %[shfl_base] \n"
" ds_read_b32 %[c2], %[v_sfl_sld] offset:64 + %[shfl_base] \n"
" ds_read_b32 %[c3], %[v_sfl_sld] offset:96 + %[shfl_base] \n"
" ds_read_b32 %[c4], %[v_sfl_sld] offset:4352 + %[shfl_base] \n"
" ds_read_b32 %[c5], %[v_sfl_sld] offset:4384 + %[shfl_base] \n"
" ds_read_b32 %[c6], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
" ds_read_b32 %[c7], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
_UK_ATOMIC_ADD_ " %[v_os_o0], %[c0], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
_UK_ATOMIC_ADD_ " %[v_os_o1], %[c1], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
_UK_ATOMIC_ADD_ " %[v_os_o2], %[c2], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
_UK_ATOMIC_ADD_ " %[v_os_o3], %[c3], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
_UK_ATOMIC_ADD_ " %[v_os_o4], %[c4], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
_UK_ATOMIC_ADD_ " %[v_os_o5], %[c5], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
_UK_ATOMIC_ADD_ " %[v_os_o6], %[c6], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
_UK_ATOMIC_ADD_ " %[v_os_o7], %[c7], s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" s_waitcnt vmcnt(32) \n"
" s_barrier \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[128:129], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[130:131], v[130:131], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[132:133], v[132:133], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[134:135], v[134:135], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[136:137], v[136:137], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[138:139], v[138:139], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[140:141], v[140:141], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[142:143], v[142:143], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[128:129], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[130:131], v[194:195], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[132:133], v[196:197], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[134:135], v[198:199], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[136:137], v[200:201], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[138:139], v[202:203], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[140:141], v[204:205], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[142:143], v[206:207], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[144:145], v[128:129], 0 \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[146:147], v[130:131], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[148:149], v[132:133], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[150:151], v[134:135], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[152:153], v[136:137], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[154:155], v[138:139], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[156:157], v[140:141], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[158:159], v[142:143], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[144:145], v[192:193], 0 \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[146:147], v[194:195], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[148:149], v[196:197], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[150:151], v[198:199], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[152:153], v[200:201], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[154:155], v[202:203], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[156:157], v[204:205], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[158:159], v[206:207], [%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[160:161], v[144:145], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[162:163], v[146:147], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[164:165], v[148:149], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[166:167], v[150:151], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[168:169], v[152:153], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[170:171], v[154:155], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[172:173], v[156:157], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[174:175], v[158:159], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[160:161], v[208:209], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[162:163], v[210:211], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[164:165], v[212:213], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[166:167], v[214:215], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[168:169], v[216:217], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[170:171], v[218:219], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[172:173], v[220:221], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[174:175], v[222:223], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[176:177], v[144:145], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[178:179], v[146:147], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[180:181], v[148:149], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[182:183], v[150:151], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[184:185], v[152:153], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[186:187], v[154:155], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[188:189], v[156:157], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[190:191], v[158:159], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[176:177], v[208:209], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[178:179], v[210:211], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[180:181], v[212:213], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[182:183], v[214:215], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[184:185], v[216:217], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[186:187], v[218:219], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[188:189], v[220:221], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[190:191], v[222:223], [%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[192:193], v[160:161], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[194:195], v[162:163], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[196:197], v[164:165], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[198:199], v[166:167], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[200:201], v[168:169], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[202:203], v[170:171], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[204:205], v[172:173], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[206:207], v[174:175], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[192:193], v[224:225], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[194:195], v[226:227], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[196:197], v[228:229], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[198:199], v[230:231], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[200:201], v[232:233], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[202:203], v[234:235], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[204:205], v[236:237], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[206:207], v[238:239], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[208:209], v[160:161], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[210:211], v[162:163], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[212:213], v[164:165], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[214:215], v[166:167], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[216:217], v[168:169], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[218:219], v[170:171], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[220:221], v[172:173], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[222:223], v[174:175], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[208:209], v[224:225], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[210:211], v[226:227], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[212:213], v[228:229], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[214:215], v[230:231], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[216:217], v[232:233], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[218:219], v[234:235], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[220:221], v[236:237], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[222:223], v[238:239], [%[c28],%[c29],%[c30],%[c31]] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[224:225], v[176:177], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[226:227], v[178:179], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[228:229], v[180:181], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[230:231], v[182:183], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[232:233], v[184:185], [%[c16],%[c17],%[c18],%[c19]] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[234:235], v[186:187], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[236:237], v[188:189], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[238:239], v[190:191], [%[c16],%[c17],%[c18],%[c19]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[224:225], v[240:241], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[226:227], v[242:243], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[228:229], v[244:245], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[230:231], v[246:247], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[232:233], v[248:249], [%[c20],%[c21],%[c22],%[c23]] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[234:235], v[250:251], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[236:237], v[252:253], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[238:239], v[254:255], [%[c20],%[c21],%[c22],%[c23]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[240:241], v[176:177], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[242:243], v[178:179], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[244:245], v[180:181], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[246:247], v[182:183], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[248:249], v[184:185], [%[c24],%[c25],%[c26],%[c27]] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[250:251], v[186:187], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[252:253], v[188:189], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[254:255], v[190:191], [%[c24],%[c25],%[c26],%[c27]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[240:241], v[240:241], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[242:243], v[242:243], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[244:245], v[244:245], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[246:247], v[246:247], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[248:249], v[248:249], [%[c28],%[c29],%[c30],%[c31]] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[250:251], v[250:251], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[252:253], v[252:253], [%[c28],%[c29],%[c30],%[c31]] \n"
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[254:255], v[254:255], [%[c28],%[c29],%[c30],%[c31]]\n"
" v_mul_f32 %[c16], %[scale_0], %[c16] \n"
" v_mul_f32 %[c17], %[scale_0], %[c17] \n"
" v_mul_f32 %[c18], %[scale_0], %[c18] \n"
" v_mul_f32 %[c19], %[scale_0], %[c19] \n"
" v_mul_f32 %[c20], %[scale_1], %[c20] \n"
" v_mul_f32 %[c21], %[scale_1], %[c21] \n"
" v_mul_f32 %[c22], %[scale_1], %[c22] \n"
" v_mul_f32 %[c23], %[scale_1], %[c23] \n"
" v_mul_f32 %[c24], %[scale_0], %[c24] \n"
" v_mul_f32 %[c25], %[scale_0], %[c25] \n"
" v_mul_f32 %[c26], %[scale_0], %[c26] \n"
" v_mul_f32 %[c27], %[scale_0], %[c27] \n"
" v_mul_f32 %[c28], %[scale_1], %[c28] \n"
" v_mul_f32 %[c29], %[scale_1], %[c29] \n"
" v_mul_f32 %[c30], %[scale_1], %[c30] \n"
" v_mul_f32 %[c31], %[scale_1], %[c31] \n"
_UK_PK_CVT_("%[c16]", "%[c17]", "%[c16]")
_UK_PK_CVT_("%[c18]", "%[c19]", "%[c17]")
_UK_PK_CVT_("%[c20]", "%[c21]", "%[c18]")
_UK_PK_CVT_("%[c22]", "%[c23]", "%[c19]")
_UK_PK_CVT_("%[c24]", "%[c25]", "%[c20]")
_UK_PK_CVT_("%[c26]", "%[c27]", "%[c21]")
_UK_PK_CVT_("%[c28]", "%[c29]", "%[c22]")
_UK_PK_CVT_("%[c30]", "%[c31]", "%[c23]")
" ;------------------------------ \n"
" ds_write_b64 %[v_sfl_sst], [%[c16],%[c17]] offset:0 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c18],%[c19]] offset:4352 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c20],%[c21]] offset:2176 + %[shfl_base] \n"
" ds_write_b64 %[v_sfl_sst], [%[c22],%[c23]] offset:6528 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_barrier \n"
" ds_read_b32 %[c16], %[v_sfl_sld] offset:0 + %[shfl_base] \n"
" ds_read_b32 %[c17], %[v_sfl_sld] offset:32 + %[shfl_base] \n"
" ds_read_b32 %[c18], %[v_sfl_sld] offset:64 + %[shfl_base] \n"
" ds_read_b32 %[c19], %[v_sfl_sld] offset:96 + %[shfl_base] \n"
" ds_read_b32 %[c20], %[v_sfl_sld] offset:4352 + %[shfl_base] \n"
" ds_read_b32 %[c21], %[v_sfl_sld] offset:4384 + %[shfl_base] \n"
" ds_read_b32 %[c22], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
" ds_read_b32 %[c23], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
_UK_ATOMIC_ADD_ " %[v_os_o0], %[c16], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
_UK_ATOMIC_ADD_ " %[v_os_o1], %[c17], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
_UK_ATOMIC_ADD_ " %[v_os_o2], %[c18], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
_UK_ATOMIC_ADD_ " %[v_os_o3], %[c19], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
_UK_ATOMIC_ADD_ " %[v_os_o4], %[c20], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
_UK_ATOMIC_ADD_ " %[v_os_o5], %[c21], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
_UK_ATOMIC_ADD_ " %[v_os_o6], %[c22], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
_UK_ATOMIC_ADD_ " %[v_os_o7], %[c23], s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
#undef _UK_MFMA_
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_

View File

@@ -0,0 +1,516 @@
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
#endif
"s_mov_b32 s16, %[s_res_a0] \n"
"s_mov_b32 s17, %[s_res_a1] \n"
"s_mov_b32 s18, %[s_res_a2] \n"
"s_mov_b32 s19, %[s_res_a3] \n"
"s_mov_b32 s20, %[s_res_b0] \n"
"s_mov_b32 s21, %[s_res_b1] \n"
"s_mov_b32 s22, %[s_res_b2] \n"
"s_mov_b32 s23, %[s_res_b3] \n"
// "s_nop 4\n"
"; -- prefetch A0\n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch A1\n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch B0\n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
"s_cselect_b32 s86, %[s_tile_os_b], 0 ; move b with cond \n"
"s_add_u32 s20, s86, s20 ; move b with cond \n"
"s_addc_u32 s21, 0, s21 ; move b with cond \n"
"s_waitcnt vmcnt(40) \n"
"s_barrier \n"
"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 K stride
"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n"
"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n"
"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n"
"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n"
"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n"
"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n"
"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n"
"L_start%=: \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n"
_UK_MFMA_ " %[v_acc_0], acc[0:1], v[64:65], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[2:3], v[66:67], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_0], acc[4:5], v[68:69], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[6:7], v[70:71], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_0], acc[8:9], v[72:73], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[10:11], v[74:75], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_0], acc[12:13], v[76:77], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[14:15], v[78:79], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[0:1], v[80:81], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[2:3], v[82:83], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_1], acc[4:5], v[84:85], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[6:7], v[86:87], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[8:9], v[88:89], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[10:11], v[90:91], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_1], acc[12:13], v[92:93], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[14:15], v[94:95], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[16:17], v[64:65], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[18:19], v[66:67], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_2], acc[20:21], v[68:69], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[22:23], v[70:71], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[24:25], v[72:73], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[26:27], v[74:75], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_2], acc[28:29], v[76:77], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[30:31], v[78:79], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[16:17], v[80:81], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[18:19], v[82:83], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_3], acc[20:21], v[84:85], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[22:23], v[86:87], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[24:25], v[88:89], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[26:27], v[90:91], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_3], acc[28:29], v[92:93], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[30:31], v[94:95], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_4], acc[32:33], v[64:65], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[34:35], v[66:67], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_4], acc[36:37], v[68:69], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[38:39], v[70:71], %[v_acc_4] \n"
" ds_read_b128 v[96:99], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_0] \n"
_UK_MFMA_ " %[v_acc_4], acc[40:41], v[72:73], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[42:43], v[74:75], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_4], acc[44:45], v[76:77], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[46:47], v[78:79], %[v_acc_4] \n"
" ds_read_b128 v[100:103], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_1] \n"
_UK_MFMA_ " %[v_acc_5], acc[32:33], v[80:81], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[34:35], v[82:83], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_5], acc[36:37], v[84:85], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[38:39], v[86:87], %[v_acc_5] \n"
" ds_read_b128 v[104:107], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_2] \n"
_UK_MFMA_ " %[v_acc_5], acc[40:41], v[88:89], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[42:43], v[90:91], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_5], acc[44:45], v[92:93], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[46:47], v[94:95], %[v_acc_5] \n"
" ds_read_b128 v[108:111], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_3] \n"
_UK_MFMA_ " %[v_acc_6], acc[48:49], v[64:65], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[50:51], v[66:67], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_6], acc[52:53], v[68:69], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[54:55], v[70:71], %[v_acc_6] \n"
" ds_read_b128 v[112:115], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_4] \n"
_UK_MFMA_ " %[v_acc_6], acc[56:57], v[72:73], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[58:59], v[74:75], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_6], acc[60:61], v[76:77], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[62:63], v[78:79], %[v_acc_6] \n"
" ds_read_b128 v[116:119], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_5] \n"
_UK_MFMA_ " %[v_acc_7], acc[48:49], v[80:81], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[50:51], v[82:83], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_7], acc[52:53], v[84:85], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[54:55], v[86:87], %[v_acc_7] \n"
" ds_read_b128 v[120:123], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_6] \n"
_UK_MFMA_ " %[v_acc_7], acc[56:57], v[88:89], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[58:59], v[90:91], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_7], acc[60:61], v[92:93], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[62:63], v[94:95], %[v_acc_7] \n"
" ds_read_b128 v[124:127], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_8], acc[64:65], v[64:65], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[66:67], v[66:67], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_8], acc[68:69], v[68:69], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[70:71], v[70:71], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[72:73], v[72:73], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[74:75], v[74:75], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_8], acc[76:77], v[76:77], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[78:79], v[78:79], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_9], acc[64:65], v[80:81], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[66:67], v[82:83], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_9], acc[68:69], v[84:85], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[70:71], v[86:87], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[72:73], v[88:89], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[74:75], v[90:91], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_9], acc[76:77], v[92:93], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[78:79], v[94:95], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_10], acc[80:81], v[64:65], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[82:83], v[66:67], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_10], acc[84:85], v[68:69], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[86:87], v[70:71], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[88:89], v[72:73], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[90:91], v[74:75], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_10], acc[92:93], v[76:77], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[94:95], v[78:79], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_11], acc[80:81], v[80:81], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[82:83], v[82:83], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_11], acc[84:85], v[84:85], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[86:87], v[86:87], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[88:89], v[88:89], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[90:91], v[90:91], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_11], acc[92:93], v[92:93], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[94:95], v[94:95], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_12], acc[96:97], v[64:65], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[98:99], v[66:67], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_12], acc[100:101], v[68:69], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[102:103], v[70:71], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[104:105], v[72:73], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[106:107], v[74:75], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_12], acc[108:109], v[76:77], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[110:111], v[78:79], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_13], acc[96:97], v[80:81], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[98:99], v[82:83], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_13], acc[100:101], v[84:85], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[102:103], v[86:87], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[104:105], v[88:89], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[106:107], v[90:91], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_13], acc[108:109], v[92:93], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[110:111], v[94:95], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_14], acc[112:113], v[64:65], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[114:115], v[66:67], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_14], acc[116:117], v[68:69], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[118:119], v[70:71], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[120:121], v[72:73], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[122:123], v[74:75], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_14], acc[124:125], v[76:77], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[126:127], v[78:79], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_15], acc[112:113], v[80:81], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[114:115], v[82:83], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_15], acc[116:117], v[84:85], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072\n"
_UK_MFMA_ " %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" ;------------------------------------------ \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n"
_UK_MFMA_ " %[v_acc_0], acc[128:129], v[96:97], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[130:131], v[98:99], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_0], acc[132:133], v[100:101], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[134:135], v[102:103], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_0], acc[136:137], v[104:105], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[138:139], v[106:107], %[v_acc_0] \n"
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_0], acc[140:141], v[108:109], %[v_acc_0] \n"
_UK_MFMA_ " %[v_acc_0], acc[142:143], v[110:111], %[v_acc_0] \n"
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[128:129], v[112:113], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[130:131], v[114:115], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_1], acc[132:133], v[116:117], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[134:135], v[118:119], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_1], acc[136:137], v[120:121], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[138:139], v[122:123], %[v_acc_1] \n"
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_1], acc[140:141], v[124:125], %[v_acc_1] \n"
_UK_MFMA_ " %[v_acc_1], acc[142:143], v[126:127], %[v_acc_1] \n"
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[144:145], v[96:97], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[146:147], v[98:99], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_2], acc[148:149], v[100:101], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[150:151], v[102:103], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_2], acc[152:153], v[104:105], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[154:155], v[106:107], %[v_acc_2] \n"
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_2], acc[156:157], v[108:109], %[v_acc_2] \n"
_UK_MFMA_ " %[v_acc_2], acc[158:159], v[110:111], %[v_acc_2] \n"
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[144:145], v[112:113], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[146:147], v[114:115], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_3], acc[148:149], v[116:117], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[150:151], v[118:119], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
_UK_MFMA_ " %[v_acc_3], acc[152:153], v[120:121], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[154:155], v[122:123], %[v_acc_3] \n"
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_3], acc[156:157], v[124:125], %[v_acc_3] \n"
_UK_MFMA_ " %[v_acc_3], acc[158:159], v[126:127], %[v_acc_3] \n"
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
" s_add_u32 m0, 0, %[s_m0_init] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_4], acc[160:161], v[96:97], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[162:163], v[98:99], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_4], acc[164:165], v[100:101], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[166:167], v[102:103], %[v_acc_4] \n"
" ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n"
_UK_MFMA_ " %[v_acc_4], acc[168:169], v[104:105], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[170:171], v[106:107], %[v_acc_4] \n"
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_4], acc[172:173], v[108:109], %[v_acc_4] \n"
_UK_MFMA_ " %[v_acc_4], acc[174:175], v[110:111], %[v_acc_4] \n"
" ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n"
_UK_MFMA_ " %[v_acc_5], acc[160:161], v[112:113], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[162:163], v[114:115], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_5], acc[164:165], v[116:117], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[166:167], v[118:119], %[v_acc_5] \n"
" ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] \n"
_UK_MFMA_ " %[v_acc_5], acc[168:169], v[120:121], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[170:171], v[122:123], %[v_acc_5] \n"
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_5], acc[172:173], v[124:125], %[v_acc_5] \n"
_UK_MFMA_ " %[v_acc_5], acc[174:175], v[126:127], %[v_acc_5] \n"
" ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] \n"
_UK_MFMA_ " %[v_acc_6], acc[176:177], v[96:97], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[178:179], v[98:99], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_6], acc[180:181], v[100:101], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[182:183], v[102:103], %[v_acc_6] \n"
" ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] \n"
_UK_MFMA_ " %[v_acc_6], acc[184:185], v[104:105], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[186:187], v[106:107], %[v_acc_6] \n"
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_6], acc[188:189], v[108:109], %[v_acc_6] \n"
_UK_MFMA_ " %[v_acc_6], acc[190:191], v[110:111], %[v_acc_6] \n"
" ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] \n"
_UK_MFMA_ " %[v_acc_7], acc[176:177], v[112:113], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[178:179], v[114:115], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_7], acc[180:181], v[116:117], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[182:183], v[118:119], %[v_acc_7] \n"
" ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] \n"
_UK_MFMA_ " %[v_acc_7], acc[184:185], v[120:121], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[186:187], v[122:123], %[v_acc_7] \n"
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_7], acc[188:189], v[124:125], %[v_acc_7] \n"
_UK_MFMA_ " %[v_acc_7], acc[190:191], v[126:127], %[v_acc_7] \n"
" ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_8], acc[192:193], v[96:97], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[194:195], v[98:99], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_8], acc[196:197], v[100:101], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[198:199], v[102:103], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[200:201], v[104:105], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[202:203], v[106:107], %[v_acc_8] \n"
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_8], acc[204:205], v[108:109], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_8], acc[206:207], v[110:111], %[v_acc_8] \n"
_UK_MFMA_ " %[v_acc_9], acc[192:193], v[112:113], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[194:195], v[114:115], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_9], acc[196:197], v[116:117], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[198:199], v[118:119], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[200:201], v[120:121], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[202:203], v[122:123], %[v_acc_9] \n"
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_9], acc[204:205], v[124:125], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_9], acc[206:207], v[126:127], %[v_acc_9] \n"
_UK_MFMA_ " %[v_acc_10], acc[208:209], v[96:97], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[210:211], v[98:99], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_10], acc[212:213], v[100:101], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[214:215], v[102:103], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[216:217], v[104:105], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[218:219], v[106:107], %[v_acc_10] \n"
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_10], acc[220:221], v[108:109], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_10], acc[222:223], v[110:111], %[v_acc_10] \n"
_UK_MFMA_ " %[v_acc_11], acc[208:209], v[112:113], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[210:211], v[114:115], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_11], acc[212:213], v[116:117], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[214:215], v[118:119], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[216:217], v[120:121], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[218:219], v[122:123], %[v_acc_11] \n"
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_11], acc[220:221], v[124:125], %[v_acc_11] \n"
_UK_MFMA_ " %[v_acc_11], acc[222:223], v[126:127], %[v_acc_11] \n"
" s_waitcnt vmcnt(32) \n"
_UK_MFMA_ " %[v_acc_12], acc[224:225], v[96:97], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[226:227], v[98:99], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_12], acc[228:229], v[100:101], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[230:231], v[102:103], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[232:233], v[104:105], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[234:235], v[106:107], %[v_acc_12] \n"
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_12], acc[236:237], v[108:109], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_12], acc[238:239], v[110:111], %[v_acc_12] \n"
_UK_MFMA_ " %[v_acc_13], acc[224:225], v[112:113], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[226:227], v[114:115], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_13], acc[228:229], v[116:117], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[230:231], v[118:119], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[232:233], v[120:121], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[234:235], v[122:123], %[v_acc_13] \n"
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_13], acc[236:237], v[124:125], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_13], acc[238:239], v[126:127], %[v_acc_13] \n"
_UK_MFMA_ " %[v_acc_14], acc[240:241], v[96:97], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[242:243], v[98:99], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
_UK_MFMA_ " %[v_acc_14], acc[244:245], v[100:101], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[246:247], v[102:103], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[248:249], v[104:105], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[250:251], v[106:107], %[v_acc_14] \n"
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
_UK_MFMA_ " %[v_acc_14], acc[252:253], v[108:109], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_14], acc[254:255], v[110:111], %[v_acc_14] \n"
_UK_MFMA_ " %[v_acc_15], acc[240:241], v[112:113], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[242:243], v[114:115], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
_UK_MFMA_ " %[v_acc_15], acc[244:245], v[116:117], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[246:247], v[118:119], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[248:249], v[120:121], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[250:251], v[122:123], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
_UK_MFMA_ " %[v_acc_15], acc[252:253], v[124:125], %[v_acc_15] \n"
_UK_MFMA_ " %[v_acc_15], acc[254:255], v[126:127], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
" s_nop 2 \n"
#undef _UK_MFMA_

View File

@@ -331,7 +331,8 @@ struct BlockFmhaPipelineQRKSVSAsync
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
@@ -355,6 +356,7 @@ struct BlockFmhaPipelineQRKSVSAsync
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
if constexpr(i_k0 < k0_loops - 1)
@@ -386,7 +388,7 @@ struct BlockFmhaPipelineQRKSVSAsync
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, bool_constant<false>{});
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(s_acc,
@@ -514,7 +516,8 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(
v_dram_window,
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
__builtin_amdgcn_sched_barrier(0);
@@ -618,7 +621,8 @@ struct BlockFmhaPipelineQRKSVSAsync
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
v_buf = load_tile(
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(o_acc,
@@ -665,8 +669,11 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
k_dram_window,
number<-1>{},
k_oob_ck,
k_pre_np);
move_tile_window(k_dram_window, {0, kK0});
}
// tail

View File

@@ -3,7 +3,15 @@
#pragma once
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"

View File

@@ -0,0 +1,421 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
//
// 32bit 0........23 24.....31 bit
// (data) -> (token_id | topk_id)
// low 24 bit is for token id, top 8 bit is for topk id
//
// the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim]
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate original rol/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of above, we can get:
// tpok_row_id(token_id) = x % num_tokens(5)
// tpok_col_id(expert_Id) = x / num_tokens
// topk_row_id/col_id can be used to access original topk_ids/topk_weight
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 5, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
//
namespace ck_tile {
// m: num_tokens (or token*input-batch)
// k: intermediate_size
// n: intermediate_size used between 2 FC (TP slice this)
// e: num expert
// if doing pre-shuffle
// nr : n / Block_Nr
// kr : k / Block_Kr
// w : fattened 1d wave buffer
struct FusedMoeGemmHostArgs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr; // [max_num_tokens_padded]
const void* sorted_weight_ptr; // [max_num_tokens_padded]
const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
const void* num_sorted_tiles_ptr; // [1]
index_t hidden_size; // k
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
// This is scatter/gather b2b group-gemm
template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmKernel
{
using Partitioner = remove_cvref_t<Partitioner_>;
using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static_assert(kBlockPerCu > 0);
using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
static constexpr index_t BlockSize_ = BlockShape::BlockSize;
using ADataType = typename Pipeline::Problem::ADataType;
using GDataType = typename Pipeline::Problem::GDataType;
using DDataType = typename Pipeline::Problem::DDataType;
using AccDataType = typename Pipeline::Problem::AccDataType;
using ODataType = typename Pipeline::Problem::ODataType;
using AScaleDataType = typename Pipeline::Problem::AScaleDataType;
using GScaleDataType = typename Pipeline::Problem::GScaleDataType;
using DScaleDataType = typename Pipeline::Problem::DScaleDataType;
using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Pipeline::Problem::TopkWeightDataType;
using IndexDataType = typename Pipeline::Problem::IndexDataType;
using YDataType = typename Pipeline::Problem::YDataType;
using Traits = typename Pipeline::Problem::Traits;
static constexpr bool UseUK = true;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using S_ = BlockShape;
auto prec_str = [&] () {
std::string base_str = _SS_(t2s<ADataType>::name);
if (!std::is_same_v<ADataType, GDataType>) {
base_str += _SS_("_") + _SS_(t2s<GDataType>::name);
}
return base_str;
}();
return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
_TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
_TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
#undef _SS_
#undef _TS_
// clang-format on
}
struct FusedMoeGemmKargs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
// TODO: switch karg based on
using Kargs = FusedMoeGemmKargs;
using Hargs = FusedMoeGemmHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
// TODO: hargs/kargs not guranteed to be the same
return bit_cast<Kargs>(hargs);
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
constexpr index_t block_m = BlockShape::Block_M0;
int max_num_tokens_padded =
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
if constexpr(UseUK)
{
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
// if(threadIdx.x == 0)
// printf("bid:%d,%d, num_sorted_tiles:%d, sorted_tile_id:%d(%d),
// intermediate_tile_id:%d\n", static_cast<int>(blockIdx.x),
// static_cast<int>(blockIdx.y), num_sorted_tiles, sorted_tile_id, sorted_tile_id >=
// num_sorted_tiles? 1 : 0, intermediate_tile_id);
if(sorted_tile_id >= num_sorted_tiles)
return;
Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id);
}
else
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 =
kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles)
return;
const IndexDataType expert_id =
__builtin_amdgcn_readfirstlane(reinterpret_cast<const IndexDataType*>(
kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t interm_idx_nr =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id =
a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
kargs.sorted_weight_ptr)[sorted_token_id];
const auto a_window = [&]() {
// A is already pre-padded in previous kernel
const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentA>{},
number<1>{});
// gather is here use indexing transform
const auto a_gather_view_ = transform_tensor_view(
a_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto a_window_ = make_tile_window(
a_gather_view_,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
return a_window_;
}();
// TODO: gtile using NSub to have less register pressure
const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr * kr_0 * BlockShape::Block_W0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<Pipeline::kAlignmentG>{},
number<1>{});
const auto g_view_1_ =
pad_tensor_view(g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
const auto g_window_ = make_tile_window(g_view_1_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0});
return g_window_;
}();
const auto d_window = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
interm_idx_nr * BlockShape::Block_W1;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<Pipeline::kAlignmentD>{},
number<1>{});
const auto d_view_1_ =
pad_tensor_view(d_view_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
sequence<PadHiddenSize, PadIntermediateSize, 0>{});
const auto d_window_ = make_tile_window(d_view_1_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
{0, 0, 0});
return d_window_;
}();
auto o_window = [&]() {
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
auto o_view_ = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
o_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentO>{},
number<1>{});
// gather is here
auto o_scatter_view_ = transform_tensor_view(
o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto o_window_ = make_tile_window(
o_scatter_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
{0, 0});
return o_window_;
}();
// do compute yeah
Pipeline{}(a_window,
g_window,
d_window,
o_window,
topk_weight,
smem,
kargs.hidden_size,
kargs.intermediate_size,
kargs.stride_token);
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,125 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
tensors:
1. act (A): input feature map
2. gate (G): B matrix for first gemm, output will do activation(Silu)
3. up (U): B matrix for first gemm
4. down (D): B matrix for second gemm
N1
/ \
+----------+ |
| Down | |
x----------x |
hidden hidden K1 | | |
N0 N0 x----------x |
| +------x-----x------+------x-----x------+ | | |
dim | | Gate | | | Up | | | | | |
contiguous | | | | | | | | | | |
| | | | | | | | | | |
v +------x-----x------+------x-----x------+ +----------+ V
K0 | | | | | contiguous
/ \ v v v v |
+---------+ +------x-----x------+------x-----x------+ |
M0 | A | | | | | | | | |
+---------+ +------x-----x------+------x-----x------+ |
----------> | | |
contiguous | V V
| x-----x +----------+
+------------> M1 | Y | ---------> | Out(O) |
ACT x-----x +----------+
K1 = N0 dim
* Note: Act could be Gelu/Silu/...
* Note: some model does not have Up
*/
template <typename BlockTile_0_,
typename WarpPerBlock_0_,
typename WarpTile_0_,
typename BlockTile_1_,
typename WarpPerBlock_1_,
typename WarpTile_1_>
struct FusedMoeGemmShape
{
using BlockTile_0 = remove_cvref_t<BlockTile_0_>;
using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_0_>;
using WarpTile_0 = remove_cvref_t<WarpTile_0_>;
using BlockTile_1 = remove_cvref_t<BlockTile_1_>;
using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_1_>;
using WarpTile_1 = remove_cvref_t<WarpTile_1_>;
static constexpr index_t NumWarps =
reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
// TODO: we don't support half warps aound to 1 warp here
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{});
static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(number<0>{});
static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(number<1>{});
static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(number<2>{});
static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{});
static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{});
static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{});
static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0;
static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0;
static constexpr index_t ThreadPerBlock_K0 = Warp_K0 * WarpPerBlock_K0;
static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
static_assert(Block_K0 % ThreadPerBlock_K0 == 0);
static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0;
static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0;
static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0;
static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{});
static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{});
static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{});
static constexpr index_t WarpPerBlock_M1 = WarpPerBlock_1::at(number<0>{});
static constexpr index_t WarpPerBlock_N1 = WarpPerBlock_1::at(number<1>{});
static constexpr index_t WarpPerBlock_K1 = WarpPerBlock_1::at(number<2>{});
static constexpr index_t Warp_M1 = WarpTile_1::at(number<0>{});
static constexpr index_t Warp_N1 = WarpTile_1::at(number<1>{});
static constexpr index_t Warp_K1 = WarpTile_1::at(number<2>{});
static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1;
static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1;
static constexpr index_t ThreadPerBlock_K1 = Warp_K1 * WarpPerBlock_K1;
static_assert(Block_M1 % ThreadPerBlock_M1 == 0);
static_assert(Block_N1 % ThreadPerBlock_N1 == 0);
static_assert(Block_K1 % ThreadPerBlock_K1 == 0);
static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1;
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;
static constexpr index_t BlockSize = warpSize * NumWarps;
// some assert
static_assert(Block_M0 == Block_M1);
static_assert(Block_N0 == Block_K1 || (Block_N0 / 2) == Block_K1); // Gate Only or Gate+Up
// pre-shuffle tile size compute (assume only for B matrix)
// we flatten the each wave tile to a 1d linear tensor(at model loading time)
// e.g. originally we have Block_N*Block_K tile size, after pre-shuffle
// we can have Block_Nr*Block_Kr*Block_W, where Block_W is Warp_N*Warp_K,
// and Block_Nr=Block_N/Warp_N, Block_Kr=Block_K/Warp_K
static constexpr index_t Block_W0 = Warp_N0 * Warp_K0;
static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0;
static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
static constexpr index_t Block_W1 = Warp_N1 * Warp_K1;
static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1;
static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;
static_assert(Block_W0 == Block_W1);
// static_assert(Block_Nr0 == Block_Kr1);
};
} // namespace ck_tile

View File

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
template <typename BlockShape_>
struct FusedMoeGemmTilePartitioner_Linear
{
// FusedMoeGemmShape
using BlockShape = ck_tile::remove_cvref_t<BlockShape_>;
static constexpr const char* name = "lin";
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
ck_tile::index_t /*intermediate_size*/)
{
index_t i_n = blockIdx.x;
index_t i_m = blockIdx.y;
return ck_tile::make_tuple(i_m, i_n);
}
CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t intermediate_size)
{
// TODO: this may need tuning
index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0);
index_t ns = ck_tile::integer_divide_ceil(intermediate_size, BlockShape::Block_N0);
return dim3(ns, ms, 1);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,651 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
namespace ck_tile {
/*
This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
<----- gemm-N ------>
+----+----+----+----+
| w0 | w1 | w2 | w3 | gemm-m
+----+----+----+----+
*/
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_FlatmmEx
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
using ADataType = typename Problem::ADataType;
using GDataType = typename Problem::GDataType;
using DDataType = typename Problem::DDataType;
using AccDataType = typename Problem::AccDataType;
using ODataType = typename Problem::ODataType;
using AScaleDataType = typename Problem::AScaleDataType;
using GScaleDataType = typename Problem::GScaleDataType;
using DScaleDataType = typename Problem::DScaleDataType;
using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Problem::TopkWeightDataType;
using IndexDataType = typename Problem::IndexDataType;
using YDataType = typename Problem::YDataType;
using Traits = typename Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
// minimize occupancy
return 2;
}
}();
static constexpr const char* name = "fused_moe_flatmm";
// TODO: there are multiple buffers
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{
return Policy::template GetSmemSize_A<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetACoord()
{
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetOCoord()
{
constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
const auto o_coord = o_dist.calculate_index();
return o_coord;
}
template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
const GWindow& g_window_,
const DWindow& d_window_,
OWindow& o_window_,
TopkWeightDataType /*topk_weight*/,
CK_TILE_LDS_ADDR void* smem,
index_t hidden_size,
index_t intermediate_size)
{
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"");
constexpr auto NEG1 = number<-1>{};
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
constexpr auto TRUE = bool_constant<true>{};
constexpr auto FALSE = bool_constant<false>{};
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
CK_TILE_LDS_ADDR ADataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(
reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) +
Policy::template GetSmemSize_A<Problem>());
auto g_view = g_window_.get_bottom_tensor_view();
auto u_view = [&]() {
if constexpr(IsGateOnly)
{
return g_view;
}
else
{
index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
const GDataType* g_ptr =
g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
const GDataType* u_ptr = g_ptr + (nr_0 / 2) * kr_0 * number<BlockShape::Block_W0>{};
const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
u_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<kAlignmentG>{},
number<1>{});
const auto u_view_1_ =
pad_tensor_view(u_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
return u_view_1_;
}
}();
auto a_win = make_tile_window_linear(
a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
auto g_win =
make_tile_window_linear(g_window_,
Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{});
auto d_win =
make_tile_window_linear(d_window_,
Policy::template MakeGlobalTileDistribution_D<Problem>(),
sequence<0, 1, 1>{});
auto o_win = make_tile_window_linear(
o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
using g_thread_type = decltype(load_tile(g_win));
using d_thread_type = decltype(load_tile(d_win));
using WarpGemm0 = decltype(Policy::template GetWarpGemm0<Problem>());
using WarpGemm1 = decltype(Policy::template GetWarpGemm1<Problem>());
auto warp_gemm_0 = WarpGemm0{};
auto warp_gemm_1 = WarpGemm1{};
// issues_warps_lanes
auto a_sst_win0 =
make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
auto a_sst_win1 =
make_tile_window(make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
{0, 0, 0});
// m*k
auto a_sld_win0 = [&]() {
using WG = WarpGemm0;
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<>,
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
sequence<BlockShape::Repeat_K0>>,
tuple<sequence<1>>,
tuple<sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
// m*k
auto a_sld_win1 = [&]() {
using WG = WarpGemm0;
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
sequence<>,
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
sequence<BlockShape::Repeat_K0>>,
tuple<sequence<1>>,
tuple<sequence<1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
{0, 0},
make_static_tile_distribution(a_block_dstr_encode));
}();
auto bridge_sst_win = [&]() {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<YDataType*>(smem),
Policy::template MakeBridgeLdsStoreDesc<Problem>()),
Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
{0, 0});
}();
auto bridge_sld_win = [&]() {
return make_tile_window_linear(
make_tensor_view<address_space_enum::lds>(
reinterpret_cast<YDataType*>(smem),
Policy::template MakeBridgeLdsLoadDesc<Problem>()),
Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
{0, 0},
Policy::template MakeYTileDistribution<Problem>());
}();
// also OK with C array, 2 register buffer
statically_indexed_array<g_thread_type, 2> gs;
constexpr auto issues_a = number<a_win.get_num_of_access()>{};
constexpr auto issues_g = number<g_win.get_num_of_access()>{};
// constexpr auto issues_d = number<d_win.get_num_of_access()>{};
// constexpr auto issues_o = number<o_win.get_num_of_access()>{};
constexpr auto issues_gemm0 =
number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
warp_gemm_0.get_num_of_access()>{};
constexpr auto issues_gemm1 =
number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
warp_gemm_1.get_num_of_access()>{};
// constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
const index_t num_blocks_k0 =
(hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
const index_t num_blocks_n1 =
(hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
using a_thread_type = decltype(load_tile(a_sld_win0));
statically_indexed_array<a_thread_type, 2> as;
auto gld_a = [&]<typename PreNop = bool_constant<false>>(
auto& a_store_, auto i_access, PreNop = {})
{
async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
};
auto move_a = [&]() {
move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
};
auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
load_tile_raw(a_, win_, i_access);
};
auto gld_g = [&]<typename PreNop = bool_constant<false>>(
auto& g_, auto i_access, PreNop = {})
{
if constexpr(IsGateOnly)
{
// TODO: hack!
if constexpr(i_access.value == 0)
{
g_win.bottom_tensor_view_ = g_view;
}
else if constexpr(i_access.value == issues_g / 2)
{
g_win.bottom_tensor_view_ = u_view;
}
}
load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
};
auto move_g = [&]() {
move_tile_window(g_win, {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
};
statically_indexed_array<d_thread_type, 2> ds;
auto gld_d = [&]<typename PreNop = bool_constant<false>>(
auto& d_, auto i_access, PreNop = {})
{
load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
};
auto move_d = [&]() {
// d move along gemm-n
move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
};
auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
auto& o_, auto i_access, PreNop = {})
{
update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
};
auto acc_0 = Policy::template MakeCBlockTile_Gemm0<Problem>();
auto acc_1s = generate_tuple(
[&](auto) { return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, number<2>{});
// clang-format off
auto gemm_0 = [&]<typename PostNop = bool_constant<false>>
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
using WarpGemm = remove_cvref_t<decltype(warp_gemm_0)>;
constexpr auto repeat_sub = WarpGemm::get_num_of_access();
constexpr auto repeat_m = BlockShape::Repeat_M0;
// constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr auto repeat_k = BlockShape::Repeat_K0;
// loop order n->m->k
constexpr auto i_sub = i_access % repeat_sub;
constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
AWarpTensor w_a;
w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor w_b;
w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
CWarpTensor w_c;
w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
warp_gemm_0(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
t_c.set_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
w_c.get_thread_buffer());
};
// clang-format on
// clang-format off
auto gemm_1 = [&]<typename PostNop = bool_constant<false>>
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
using WarpGemm = remove_cvref_t<decltype(warp_gemm_1)>;
constexpr auto repeat_sub = WarpGemm::get_num_of_access();
constexpr auto repeat_m = BlockShape::Repeat_M0;
// constexpr auto repeat_n = BlockShape::Repeat_N0;
constexpr auto repeat_k = BlockShape::Repeat_K0;
// loop order n->m->k
constexpr auto i_sub = i_access % repeat_sub;
constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WarpGemm::CWarpDstr;
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
AWarpTensor w_a;
w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
BWarpTensor w_b;
w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
CWarpTensor w_c;
w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
warp_gemm_1(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
t_c.set_y_sliced_thread_data(
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
w_c.get_thread_buffer());
};
// clang-format on
_Pragma("clang diagnostic pop");
// this gemm pipeline is designed with assumption that issues of buffer-load/ds_read can
// be hide under mfma. In other words, issues of mfma is >= memory this is true if we
// pre-shuffle B matrix, and A matrix is relatively small we prefer use multiple mfma
// paired with 1 buffer-load B matrix, to get max throughput of buffer_load. and by
// preshuffle, we always pack to dwordx4 load, and this will already extend to multiple
// mfma but that is already consumed inside warpgemm-impl. So indeed how many extra
// mfma(that can reuse the B matrix) only affected by M repeat.
auto pipeline_gemm0 = [&]() {
constexpr index_t total_loops = issues_gemm0;
constexpr auto sr = Policy::template GetSequencer_0<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_sld_a_0 = MAKE_SC();
constexpr auto c_gld_a_0 = MAKE_SC();
constexpr auto c_gld_b_0 = MAKE_SC();
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & SLD_A)
sld_a(as[I1], a_sld_win1, number<NEXT_SCI(c_sld_a_0, i_issue)>{});
if constexpr(slot & GLD_A)
gld_a(a_sst_win0, number<NEXT_SCI(c_gld_a_0, i_issue)>{});
if constexpr(slot & GLD_B)
gld_g(gs[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
});
move_g();
move_a();
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
constexpr auto c_sld_a_1 = MAKE_SC();
constexpr auto c_gld_a_1 = MAKE_SC();
constexpr auto c_gld_b_1 = MAKE_SC();
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I1], gs[I1], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & SLD_A)
sld_a(as[I0], a_sld_win0, number<NEXT_SCI(c_sld_a_1, i_issue)>{});
if constexpr(slot & GLD_A)
gld_a(a_sst_win1, number<NEXT_SCI(c_gld_a_1, i_issue)>{});
if constexpr(slot & GLD_B)
gld_g(gs[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
});
move_g();
move_a();
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
};
auto pipeline_gemm0_tail = [&]() {
constexpr index_t total_loops = issues_gemm0;
constexpr auto sr = Policy::template GetSequencer_0<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gld_b_0 = MAKE_SC();
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_0(acc_0, as[I0], gs[I0], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GLD_B)
gld_g(gs[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
});
block_sync_load_raw(issues_g);
sld_a(as[I1], a_sld_win1, NEG1);
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
constexpr auto last_nop = [&]() {
if constexpr(i_issue == (total_loops - 1))
return TRUE;
else
return FALSE;
}();
gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop); // last gemm has nop
});
};
auto y = Policy::template MakeYBlockTile<Problem>();
auto pipeline_bridge = [&]() {
// cast to Y data
auto y_pre = cast_tile<YDataType>(acc_0);
store_tile(bridge_sst_win, y_pre);
clear_tile(acc_1s(I0));
// wave_barrier();
load_tile(y, bridge_sld_win);
clear_tile(acc_1s(I1));
};
// note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1)
auto pipeline_gemm1 = [&]() {
constexpr index_t total_loops = issues_gemm1;
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gld_b_0 = MAKE_SC();
constexpr auto c_gst_o_0 = MAKE_SC();
constexpr auto c_gld_b_1 = MAKE_SC();
constexpr auto c_gst_o_1 = MAKE_SC();
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GLD_B)
gld_d(ds[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
if constexpr(slot & GST_O)
{
auto out = cast_tile<ODataType>(acc_1s[I0]);
atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
}
});
move_d();
// move_o();
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GLD_B)
gld_d(ds[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
if constexpr(slot & GST_O)
{
auto out = cast_tile<ODataType>(acc_1s[I1]);
atomic_add_o(out, number<NEXT_SCI(c_gst_o_1, i_issue)>{});
}
});
move_d();
};
auto pipeline_gemm1_head = [&]() {
constexpr index_t total_loops = issues_gemm1;
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gld_b_0 = MAKE_SC();
// compute buffer 0
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GLD_B)
gld_d(ds[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
});
move_d();
};
auto pipeline_gemm1_tail = [&]() {
constexpr index_t total_loops = issues_gemm1;
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
static_assert(sr.size() == total_loops);
constexpr auto c_gst_o_0 = MAKE_SC();
// compute buffer 1
static_for<0, total_loops, 1>{}([&](auto i_issue) {
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
constexpr index_t slot = sr.at(i_issue);
if constexpr(slot & GST_O)
{
auto out = cast_tile<ODataType>(acc_1s[I0]);
atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
}
});
{
auto out = cast_tile<ODataType>(acc_1s[I1]);
atomic_add_o(out, NEG1);
}
};
// start of pipeline
// clang-format off
gld_a(a_sst_win0, NEG1, TRUE);
gld_g(gs[I0], NEG1, TRUE);
move_a();
move_g();
clear_tile(acc_0);
// preload for next round
gld_a(a_sst_win1, NEG1);
gld_g(gs[I1], NEG1);
// make sure a,g loaded
block_sync_load_raw(issues_a + issues_g);
lds_load_fence();
// we manually unroll double buffer inside hot loop
const index_t iters_0 = (num_blocks_k0 - 2) / 2;
index_t i_0 = 0; // (void)i_0; (void)iters_0; (void)pipeline_gemm0;
while(i_0++ < iters_0)
{
pipeline_gemm0();
}
pipeline_gemm0_tail();
pipeline_bridge();
const index_t iters_1 = (num_blocks_n1 - 2) / 2;
index_t i_1 = 0; // (void) i_1; (void)iters_1; (void)pipeline_gemm1;
pipeline_gemm1_head();
while(i_1++ < iters_1)
{
pipeline_gemm1();
}
pipeline_gemm1_tail();
// clang-format on
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,831 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
struct FusedMoeGemmPipelineFlatmmPolicy
{
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
// TODO: always 1 dword
return 1;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
{
// using async
constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
{
constexpr index_t copy_bytes = [&]() { return 16; }();
constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
{
constexpr index_t copy_bytes = [&]() { return 16; }();
constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O()
{
if constexpr(Problem::Traits::OAtomic == 1)
{
// pack fp16/bf16 atomic
static_assert(sizeof(typename Problem::ODataType) == 2);
return 2;
}
else if constexpr(Problem::Traits::OAtomic == 2)
{
// fp32 atomic
return 1;
}
else
{
return 16 / sizeof(typename Problem::ODataType);
}
}
template <typename DataType_>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
{
// TODO: this is for 3d layout
return 16 / sizeof(remove_cvref_t<DataType_>);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
{
return GetSmemKPack<typename Problem::ADataType>();
}
// used for bridge LDS shuffle
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y()
{
// TODO: this should match mfma layout
return 16 / sizeof(typename Problem::YDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{
constexpr auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
constexpr auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
return a_sld_desc.get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
{
constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>();
constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>();
static_assert(bridge_sld_desc.get_element_space_size() ==
bridge_sst_desc.get_element_space_size());
return bridge_sld_desc.get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
constexpr index_t a_lds = GetSmemSize_A<Problem>();
constexpr index_t bridge_lds = GetSmemSize_Bridge<Problem>();
return max(a_lds, bridge_lds);
}
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
{
constexpr index_t K_vec = Alignment;
constexpr index_t K_rem = KPerBlock / K_vec;
if constexpr(get_warp_size() < K_rem)
{
static_assert(K_rem % get_warp_size() == 0);
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
constexpr index_t K_wav = K_rem / get_warp_size();
static_assert(K_wav <= NumWarps, "not not support thread has repeat along K yet");
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
constexpr index_t K_lan = K_rem;
constexpr index_t M_lan = get_warp_size() / K_lan;
constexpr index_t M_wav = NumWarps;
static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
// optimized version for async, not same as simple MXK dist(pay attention!!)
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
{
constexpr index_t K_vec = Alignment;
constexpr index_t K_rem = KPerBlock / K_vec;
if constexpr(get_warp_size() <= K_rem)
{
static_assert(K_rem % get_warp_size() == 0);
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
constexpr index_t K_wav = K_rem / get_warp_size();
static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet");
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / M_wav;
// NOTE: no swap, but hard to avoid LDS bank conflict
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
constexpr index_t K_lan = K_rem;
constexpr index_t M_lan = get_warp_size() / K_lan;
constexpr index_t M_wav = NumWarps;
static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
// NOTE: swapped for LDS load bank conflict free
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
// Note M_wave(num waves) is the fastest dim, different from sipmle 2d
// distribution
tuple<sequence<M_rep, M_lan, M_wav>, sequence<K_lan, K_vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
template <index_t WarpPerBlock_N_,
index_t WarpPerBlock_K_,
index_t Repeat_N_,
index_t Repeat_K_,
index_t WarpSize_,
index_t Alignment_>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_Nr_Kr_W()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<Repeat_N_, WarpPerBlock_N_>,
sequence<Repeat_K_, WarpPerBlock_K_>,
sequence<WarpSize_, Alignment_>>,
tuple<sequence<1, 2>, sequence<3>>,
tuple<sequence<1, 1>, sequence<0>>,
sequence<1, 2, 3>,
sequence<0, 0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
{
constexpr index_t Block_M_ = Problem::BlockShape::Block_M0;
constexpr index_t Block_K_ = Problem::BlockShape::Block_K0;
constexpr index_t NumWarps_ = Problem::BlockShape::NumWarps;
constexpr index_t Alignment_ = GetAlignment_A<Problem>();
return MakeGlobalTileDistribution_SimpleMxK_Async<Block_M_,
Block_K_,
NumWarps_,
Alignment_>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{
constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
// constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
using S_ = typename Problem::BlockShape;
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
// number<S_::WarpPerBlock_N0>{}.rrr();
// number<S_::Repeat_N0>{}.eee();
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
S_::WarpPerBlock_K0,
S_::Repeat_N0, /// hidden_radio_0,
S_::Repeat_K0,
get_warp_size(),
GetAlignment_G<Problem>()>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
{
constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
using S_ = typename Problem::BlockShape;
if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
{
return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N1,
S_::WarpPerBlock_K1,
S_::Repeat_N1,
S_::Repeat_K1,
get_warp_size(),
GetAlignment_D<Problem>()>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK > NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = Block_M / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_pass_through_transform(number<NumIssues>{}),
make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KVector>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
return lds_block_desc_issues_warps_lanes;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
{
// A async->LDS
// Note that, this descriptor is only to construct the layout inside LDS
// in real Gemm pipeline, ds_read may not follow this pattern
// (may follow that in tile_distribution)
// below code is almost the same as SmemStore dist, with difference:
// 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
// 2). return discriptor is in NxK 2d layout
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t warpSize = ck_tile::get_warp_size();
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= warpSize)
{
// need multiple waves to load K
static_assert(LanesPerK % warpSize == 0);
constexpr index_t wavesPerK = LanesPerK / warpSize;
if constexpr(wavesPerK >= NumWarps)
{
// TODO: need multiple issues along K to load all data
}
else
{
constexpr index_t wavesPerM = NumWarps / wavesPerK;
constexpr index_t NumIssues = Block_M / wavesPerM;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<wavesPerM>{}, // m1
number<wavesPerK>{}, // k0
number<warpSize>{}, // k1
number<KVector>{}), // k2
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
number<warpSize * KVector + KPad>{}, // k0
number<KVector>{}, // k1
number<1>{}), // k2
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
make_merge_transform(make_tuple(
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
}
else
{
// lanes within a wave load different M but same K
static_assert(warpSize % LanesPerK == 0);
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumIssues>{}, // m0
number<LaneGroups>{}, // m1
number<NumWarps>{}, // m2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
number<Block_K>{}, // m1
number<warpSize * KVector + KPad>{}, // m2
number<KVector>{}, // k0
number<1>{}), // k1
number<KPack>{}, // lds load vector
number<1>{});
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(
make_merge_transform(
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_desc_m_k;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc()
{
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = 0; // pad between warps
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KVector>{},
number<1>{});
return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc()
{
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
constexpr index_t KPad = 0; // KVector; // pad between warps
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KVector>{},
number<1>{});
return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreForUKDesc()
{
constexpr index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0;
constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N0;
constexpr index_t Repeat_M = Problem::BlockShape::Repeat_M0;
constexpr index_t kAMLane = 16;
constexpr index_t kABKLane = 4;
constexpr index_t kABKPerLane = 4;
constexpr index_t KPack = kABKPerLane;
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<Repeat_M>{}, // m
number<Repeat_N>{}, // n
number<WarpPerBlock_N>{}, // n
number<kABKLane>{}, // n
number<kAMLane>{}, // m
number<KPack>{}), // n
make_tuple(number<Repeat_N * WarpPerBlock_N * kABKLane * kAMLane * KPack>{}, // m
number<WarpPerBlock_N * kABKLane * kAMLane * KPack>{}, // n
number<kABKLane * kAMLane * KPack>{}, // n
number<kAMLane * KPack>{}, // n
number<KPack>{}, // m
number<1>{}), // n
number<KPack>{}, // lds store vector(actually no explicit store)
number<1>{});
constexpr auto desc = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
make_merge_transform(make_tuple(number<Repeat_N>{},
number<WarpPerBlock_N>{},
number<kABKLane>{},
number<KPack>{}))),
make_tuple(sequence<0, 4>{}, sequence<1, 2, 3, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
{
using S_ = typename Problem::BlockShape;
// A is vgpr, B is agpr. But since we transposed, so also need swap this
// TODO: this is ugly
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
2>>{};
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0()
{
// this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
// the purpose is to hide thoes instructions under mfma
// every value inside seq<...> is a mask, indicating a specific operation
using S_ = typename Problem::BlockShape;
constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 0
GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 1
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 2
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 3
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
return seq_all;
// clang-format on
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 32 instructions, 16 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 0
GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, // 1
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, // 2
GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A>{}; // 3
return seq_all;
// clang-format on
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1()
{
// this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
// the purpose is to hide thoes instructions under mfma
// every value inside seq<...> is a mask, indicating a specific operation
using S_ = typename Problem::BlockShape;
constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 0
GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 3
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
return seq_all;
// clang-format on
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
S_::Block_N1 == 128)
{
// Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
// gld_a 8x ds_read_b128 sld_a total 64 slot :)
// clang-format off
constexpr auto seq_all =
// 0 1 2 3 4 5 6 7
sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 0
GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, // 1
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 3
return seq_all;
// clang-format on
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
using S_ = typename Problem::BlockShape;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
2>>{};
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M0, S_::WarpPerBlock_M0>,
sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// this is used as A matrix for 2nd gemm
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// TODO: all waves a along different N, but same M
constexpr auto y_outer_dstr_enc =
tile_distribution_encoding<sequence<S_::WarpPerBlock_M1>,
tuple<sequence<S_::Repeat_M1>, sequence<S_::Repeat_K1>>,
tuple<sequence<0>>,
tuple<sequence<0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
return y_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile()
{
constexpr auto y_block_dstr = MakeYTileDistribution<Problem>();
auto y_block_tensor =
make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
return y_block_tensor;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetUK_0()
{
using S_ = typename Problem::BlockShape;
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return Flatmm_32x512x128_1x4x1_16x16x32_BF16{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return Flatmm_32x512x128_1x4x1_16x16x32_FP16{};
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
{
using S_ = typename Problem::BlockShape;
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::TopkWeightDataType, float> &&
S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
{
return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,354 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
namespace ck_tile {
/*
This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
<----- gemm-N ------>
+----+----+----+----+
| w0 | w1 | w2 | w3 | gemm-m
+----+----+----+----+
*/
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_FlatmmUk
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
using ADataType = typename Problem::ADataType;
using GDataType = typename Problem::GDataType;
using DDataType = typename Problem::DDataType;
using AccDataType = typename Problem::AccDataType;
using ODataType = typename Problem::ODataType;
using AScaleDataType = typename Problem::AScaleDataType;
using GScaleDataType = typename Problem::GScaleDataType;
using DScaleDataType = typename Problem::DScaleDataType;
using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Problem::TopkWeightDataType;
using IndexDataType = typename Problem::IndexDataType;
using YDataType = typename Problem::YDataType;
using Traits = typename Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
// minimize occupancy
return 2;
}
}();
static constexpr const char* name = "flatmm_uk";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_0, max(smem_1, smem_bridge));
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetACoord()
{
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetOCoord()
{
constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
const auto o_coord = o_dist.calculate_index();
return o_coord;
}
CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
{
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
constexpr index_t MLans = BlockShape::BlockSize / KLans;
constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
return MRepeat;
}
// TODO: properlly support scatter/gather
CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
{
constexpr index_t KLans = BlockShape::Block_K0 / kAlignmentA;
constexpr index_t MLans = BlockShape::BlockSize / KLans;
constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;
auto base_coord = threadIdx.x / KLans + base_offset;
array<index_t, MRepeat> coords;
static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });
return coords;
}
template <typename ROW_COORDS>
CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr)
{
constexpr index_t n_size = coords.size();
array<index_t, n_size> row_ids;
static_for<0, n_size, 1>{}([&](auto i) {
row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
});
return row_ids;
}
template <typename ROW_COORDS>
CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
const TopkWeightDataType* sorted_weight_ptr)
{
constexpr index_t n_size = coords.size();
array<TopkWeightDataType, n_size> w;
static_for<0, n_size, 1>{}([&](auto i) {
w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
});
return w;
}
// TODO: this row id is before shuffle atomic, need use acc distribution
CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
{
constexpr index_t MLanes = BlockShape::Warp_M1;
constexpr index_t Repeat_M = BlockShape::Repeat_M1;
auto base_coord = threadIdx.x % MLanes + base_offset;
array<index_t, Repeat_M> coords;
static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; });
return coords;
}
template <typename Karg>
CK_TILE_DEVICE auto operator()(const Karg& kargs,
CK_TILE_LDS_ADDR void* smem,
index_t sorted_tile_id,
index_t intermediate_tile_id)
{
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
// nr*kr*w
index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane(
intermediate_tile_id *
BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)
index_t interm_idx_kr1 = __builtin_amdgcn_readfirstlane(
intermediate_tile_id *
BlockShape::Block_Kr1); // intermediate_tile_id * Block_N / (N in W)
auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
auto row_ids_a = GetRowID(
row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
auto a_coords = generate_tuple(
[&](auto i) {
return row_ids_a[i] * kargs.stride_token +
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
},
number<row_ids_a.size()>{});
auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
auto g_win = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
number<kAlignmentG>{},
number<1>{});
auto g_window_ = make_tile_window_linear_raw(
g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
{0, 0, 0},
Policy::template MakeGlobalTileDistribution_G<Problem>(),
sequence<0, 1, 1>{});
return g_window_;
}();
auto g_res = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
number<decltype(g_win)::NumAccess_NonLinear>{});
const auto d_win = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
interm_idx_kr1 * BlockShape::Block_W1;
// note interm_idx_nr0 is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, BlockShape::Block_W1),
make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
number<kAlignmentD>{},
number<1>{});
const auto d_window_ = make_tile_window_linear_raw(
d_view_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
{0, 0, 0},
Policy::template MakeGlobalTileDistribution_D<Problem>(),
sequence<0, 1, 1>{});
return d_window_;
}();
auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
// TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
// block-k=512, block-n=128
// wg |<----- W_ ----->|
// Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
// y p y y p p y
// 1 2 0(imm)
auto d_coords = [&]() {
constexpr index_t Nr_ = 2;
constexpr index_t Nw_ = 4;
constexpr index_t Kr0_ = 4;
constexpr index_t Kr1_ = 4;
constexpr index_t Kl_ = 4;
constexpr index_t Nl_ = 16;
constexpr index_t Kv_ = 8;
constexpr index_t W_ = Kl_ * Nl_ * Kv_;
constexpr index_t num_offsets_ = Nr_ * Kr0_;
index_t base_os_ = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) *
shared_intermediate_size_1 *
Nl_; // Kr0_ * Kr1_ * W_;
return generate_tuple(
[&](auto i) {
constexpr auto i_nr_ = number<i % Nr_>{};
constexpr auto i_kr0_ = number<i / Nr_>{};
return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
base_os_;
},
number<num_offsets_>{});
}();
auto o_coords = generate_tuple(
[&](auto i) {
return row_ids_a[i] * kargs.stride_token +
threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
},
number<row_ids_a.size()>{});
auto o_flags =
generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
number<row_ids_a.size()>{});
auto bridge_sst_win = [&]() {
constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
reinterpret_cast<YDataType*>(smem), desc_),
desc_.get_lengths(),
{0, 0},
dist_);
}();
auto o_res =
make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ODataType));
auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
auto uk_0 = Policy::template GetUK_0<Problem>();
auto acc_0 = uk_0(a_res,
a_coords,
g_res,
g_coords,
smem,
kargs.hidden_size,
BlockShape::Block_K0, // tile offset for B matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for B matrix each unroll
sweep_tile(
acc_0,
[&](auto idx0, auto idx1) {
fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
typename Problem::GateActivation{}(v_, v_);
acc_0(idx0) = v_.x;
acc_0(idx1) = v_.y;
},
sequence<1, 2>{});
auto y_pre = cast_tile<YDataType>(acc_0);
block_sync_lds();
store_tile(bridge_sst_win, y_pre);
block_sync_lds();
auto uk_1 = Policy::template GetUK_1<Problem>();
uk_1(d_res,
d_coords,
o_res,
o_coords,
o_flags,
smem,
kargs.hidden_size, // total n number
w_scale,
BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
BlockShape::Block_N1); // along N
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,46 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// TODO: alow 2 gemm have different type
template <typename ADataType_,
typename GDataType_,
typename DDataType_,
typename AccDataType_,
typename ODataType_,
typename AScaleDataType_,
typename GScaleDataType_,
typename DScaleDataType_,
typename YSmoothScaleDataType_,
typename TopkWeightDataType_,
typename IndexDataType_, // data type for all indexing
typename GateActivation_, // = ck_tile::element_wise::Silu,
typename BlockShape_, // shoule be FusedMoeGemmShape
typename Traits_>
struct FusedMoeGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using GDataType = remove_cvref_t<GDataType_>;
using DDataType = remove_cvref_t<DDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using AScaleDataType = remove_cvref_t<AScaleDataType_>;
using GScaleDataType = remove_cvref_t<GScaleDataType_>;
using DScaleDataType = remove_cvref_t<DScaleDataType_>;
using YSmoothScaleDataType = remove_cvref_t<YSmoothScaleDataType_>;
using TopkWeightDataType = remove_cvref_t<TopkWeightDataType_>;
using IndexDataType = remove_cvref_t<IndexDataType_>;
// the input for next gemm should have same time as
using YDataType = ADataType;
using GateActivation = remove_cvref_t<GateActivation_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using Traits = remove_cvref_t<Traits_>;
};
} // namespace ck_tile

View File

@@ -0,0 +1,48 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
enum class FusedMoeGemmWeightPermuteEnum
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
no_permute = 0,
b_nr_kr_kw_nw_kv = 1, // 0,1,3,4,2,5
b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
};
template <bool IsGateOnly_,
bool UseSmoothQuant_,
index_t OAtomic_, // 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
bool PadHiddenSize_ = false,
bool PadIntermediateSize_ = false>
struct FusedMoeGemmTraits
{
// Gate+Up or Gate only
static constexpr bool IsGateOnly = IsGateOnly_;
static constexpr bool UseSmoothQuant = UseSmoothQuant_;
static constexpr index_t OAtomic = OAtomic_;
static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
static constexpr bool PadHiddenSize = PadHiddenSize_;
static constexpr bool PadIntermediateSize = PadIntermediateSize_;
};
// Note: this need to be a bit mask
enum class FusedMoeGemmPipelineSequencerEnum
{
SLD_A = 1 << 0, // shared load a
SLD_B = 1 << 1,
GLD_A = 1 << 2, // global load a
GLD_B = 1 << 3,
SST_A = 1 << 4, // shared store a
SST_B = 1 << 5,
GST_O = 1 << 6, // global store out
};
} // namespace ck_tile

View File

@@ -10,114 +10,134 @@
namespace ck_tile {
// fp16
using WarpGemmMfmaF16F16F32M32N32K8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
using WarpGemmMfmaF16F16F32M16N16K16 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M32N32K16 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K32 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>;
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16,
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
// bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16,
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
2>>;
// fp8
using WarpGemmMfma_f32_32x32x16_fp8_fp8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8 =
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl<
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>,
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
2,
swizzle_factor>>;

View File

@@ -25,6 +25,8 @@ struct WarpGemmAtrributeMfma
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
@@ -51,10 +53,13 @@ struct WarpGemmAtrributeMfma
sequence<0, 2>>;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
Impl{}(c_vec, a_vec, b_vec);
Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
@@ -85,6 +90,8 @@ struct WarpGemmAtrributeMfmaIterateK
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
@@ -111,8 +118,11 @@ struct WarpGemmAtrributeMfmaIterateK
sequence<0, 2>>;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -122,10 +132,33 @@ struct WarpGemmAtrributeMfmaIterateK
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
});
}
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
@@ -168,6 +201,8 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
@@ -194,11 +229,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
sequence<0, 2>>;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// swap A and B
Impl{}(c_vec, b_vec, a_vec);
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
@@ -226,6 +264,8 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
@@ -255,12 +295,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
sequence<2, 2>,
sequence<0, 2>>;
template <bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// swap A and B
Impl{}(c_vec, b_vec, a_vec);
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
}
// c_vec = a_vec * b_vec
@@ -291,6 +334,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
static constexpr index_t kN = Impl::kM;
static constexpr index_t kK = Impl::kK * kKIter;
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
@@ -316,9 +361,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
sequence<2, 2>,
sequence<0, 2>>;
template <bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -328,10 +376,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter]);
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
});
}
template <index_t iKIter, bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
@@ -377,6 +449,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
@@ -429,8 +503,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
sequence<0, 2>>;
#endif
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -440,10 +517,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter]);
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
});
}
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
@@ -488,6 +588,8 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
using AWarpDstrEncoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
@@ -518,8 +620,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
sequence<0, 2>>;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -529,10 +634,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]);
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
});
}
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -7,12 +7,68 @@
namespace ck_tile {
// TODO: refactor warp-gemm
// currently there is a discrepency for vav/vva if we need transpose C/D
// e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrEnum
// because we swap the A/B pointer in _impl code (but not known this info here)
enum class WGAttrCtlEnum
{
Default_ = 0,
Raw_vvv = 1, // c-vgpr, a-vgpr, b-vgpr
Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr
Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr
Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr
Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr
// raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
};
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
if constexpr(post_nop_) \
{ \
asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
"s_nop 3" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
} \
else \
{ \
asm volatile(mfma_ " %0, %1, %2, %3\n" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
}
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
{ \
DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
}
// FP16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
{
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
@@ -33,16 +89,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
@@ -52,18 +115,20 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
#else
ignore = a_vec;
ignore = b_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
{
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = fp16_t;
using BDataType = fp16_t;
using CDataType = float;
using AVecType = ext_vector_t<fp16_t, 4>;
using BVecType = ext_vector_t<fp16_t, 4>;
@@ -84,16 +149,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
@@ -103,19 +175,21 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
ignore = a_vec;
ignore = b_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
}
};
// Bf16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
{
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
using AVecType = ext_vector_t<bf16_t, 4>;
using BVecType = ext_vector_t<bf16_t, 4>;
@@ -136,28 +210,35 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
else
{
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
@@ -181,18 +262,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
});
return c_vec;
#else
ignore = a_vec;
ignore = b_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
{
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = bf16_t;
using BDataType = bf16_t;
using CDataType = float;
using AVecType = ext_vector_t<bf16_t, 4>;
using BVecType = ext_vector_t<bf16_t, 4>;
@@ -213,28 +296,34 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
{
#if defined(__gfx90a__) || defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
static_for<0, 2, 1>{}([&](auto k) {
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
c_vec,
0,
0,
0);
});
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
@@ -258,20 +347,21 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
});
return c_vec;
#else
ignore = a_vec;
ignore = b_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
}
};
// FP8
template <typename AType_, typename BType_>
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
{
using ADataType = AType_;
using BDataType = BType_;
using CDataType = float;
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = AType_;
using BDataType = BType_;
using CDataType = float;
using AVecType = ext_vector_t<ADataType, 8>;
using BVecType = ext_vector_t<BDataType, 8>;
@@ -292,38 +382,120 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
{
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
}
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
{
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
}
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
{
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v")
}
}
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
{
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v")
}
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
{
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v")
}
}
else
{
#if defined(__gfx94__)
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
.template get_as<ADataType>()[number<k>{}]);
float b_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
.template get_as<BDataType>()[number<k>{}]);
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
.template get_as<ADataType>()[number<k>{}]);
float b_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
.template get_as<BDataType>()[number<k>{}]);
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
});
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
});
#else
ignore = c_vec;
ignore = a_vec;
ignore = b_vec;
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
@@ -356,20 +528,97 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
});
return c_vec;
#else
ignore = a_vec;
ignore = b_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
}
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t>;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t>;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t, Ctrl_>;
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 =
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t>;
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t, Ctrl_>;
// int8
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
{
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int32_t;
using AVecType = ext_vector_t<ADataType, 8>;
using BVecType = ext_vector_t<BDataType, 8>;
using CVecType = ext_vector_t<CDataType, 16>;
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
static constexpr index_t kK = 16;
static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2;
static constexpr index_t kABKPerLane = 8;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 32;
static constexpr index_t kCM0PerLane = 4;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
else
{
#if defined(__gfx94__)
c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8(
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
static_for<0, 8, 1>{}([&](auto k) {
float a_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
.template get_as<ADataType>()[number<k>{}]);
float b_f32 =
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
.template get_as<BDataType>()[number<k>{}]);
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
});
#else
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
CVecType c_vec{0};
operator()(c_vec, a_vec, b_vec);
return c_vec;
}
};
#undef DISPATCH_MFMA_
} // namespace ck_tile

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
// clang-format off
// fp16
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
// bf16
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
// fp8
template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
template<> struct WarpGemmMfmaDispatcher<fp8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
// clang-format on
} // namespace impl

View File

@@ -31,11 +31,21 @@ struct WarpGemmImpl
using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
{
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
return WarpGemmAttribute_::get_num_of_access();
}
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
CK_TILE_DEVICE void
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
{
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
@@ -44,18 +54,49 @@ struct WarpGemmImpl
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
// c_vec += a_vec * b_vec
WarpGemmAttribute{}(c_vec, a_vec, b_vec);
WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
CK_TILE_DEVICE auto operator()(const AWarpTensor& a, const BWarpTensor& b) const
template <typename CTensor,
typename ATensor,
typename BTensor,
index_t i_subk,
bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CTensor& c,
const ATensor& a,
const BTensor& b,
number<i_subk>,
bool_constant<post_nop_> = {}) const
{
CWarpTensor c;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
// c_vec += a_vec * b_vec
WarpGemmAttribute{}(c_vec, a_vec, b_vec, number<i_subk>{}, bool_constant<post_nop_>{});
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
template <typename ATensor, typename BTensor>
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{
using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
CTensor c;
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
constexpr auto I0 = number<0>{};

View File

@@ -1,11 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"