mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] fused-moe first version (#1634)
* moe pipeline * update code * compile OK * update * update cpu reference * update pipeline_gemm0 * compiler ok * update pipeline * rename to ex pipeline * block-asm * update * update * update first gemm ok * compute correct * update file structure * update README * update * update * update code * update API * return unsupport case * add comment * update readme * update * uncomment * update * fix build err --------- Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
|
||||
else if(t.permute.compare("0,1,3,4,2,5") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
|
||||
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
|
||||
else if(t.permute.compare("0,1,3,4,2,5") == 0)
|
||||
{
|
||||
constexpr matrix_core_permute_style pstyle =
|
||||
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
|
||||
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
|
||||
using Kernel =
|
||||
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
|
||||
|
||||
|
||||
@@ -42,8 +42,8 @@ enum class matrix_core_permute_style
|
||||
{
|
||||
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
|
||||
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
|
||||
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
|
||||
b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
|
||||
};
|
||||
|
||||
// assume this is B matrix, originally we have batch*n*k
|
||||
@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel
|
||||
else
|
||||
{
|
||||
// clang-format off
|
||||
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
|
||||
// b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten
|
||||
constexpr index_t Kv = Alignment;
|
||||
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return tmp_1;
|
||||
#else
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
|
||||
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel
|
||||
else
|
||||
{
|
||||
#if MERGE_2D_013425
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
|
||||
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
{i_n * NPerBlock, i_k * KPerBlock},
|
||||
get_dst_dist());
|
||||
#else
|
||||
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
|
||||
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
|
||||
@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
|
||||
{
|
||||
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
// b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
|
||||
matrix_core_swizzle_traits t;
|
||||
t.data_type = data_type;
|
||||
t.permute = arg_parser.get_str("perm");
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/moe_sorting.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
|
||||
struct moe_sorting_trait
|
||||
{
|
||||
|
||||
19
example/ck_tile/15_fused_moe/CMakeLists.txt
Normal file
19
example/ck_tile/15_fused_moe/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe")
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding ${TILE_EXAPMLE_FUSED_MOE}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp)
|
||||
target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS})
|
||||
|
||||
set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a
|
||||
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta
|
||||
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
|
||||
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
target_compile_options(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS})
|
||||
69
example/ck_tile/15_fused_moe/README.md
Normal file
69
example/ck_tile/15_fused_moe/README.md
Normal file
@@ -0,0 +1,69 @@
|
||||
# fused-moe
|
||||
Implementing the fused-moe block operator using ck-tile. This is a scatter/gather-group-gemm based solution, similiar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance
|
||||

|
||||
|
||||
The benifit of this fused-moe:
|
||||
* 1.5~2x perf boost compared with current vllm solution
|
||||
* zero workspace to reduce memory footprint
|
||||
* much less kernel instance, easy to maintain
|
||||
|
||||
# Implementation and feature support
|
||||
## moe-sorting
|
||||
this is a common pre-process step before the actual moe-gemm. The purpose is to transform the moe loop over from token-by-token to expert-by-expert, make sure very workgroup is working for a single expert (B matrix). Besides, we extend this op to do the zeroing of the output buffer(to be used for reduce buffer with atomic)
|
||||
|
||||
## moe-gemm
|
||||
`moe-gemm` is a group-gemm based back-to-back gemm, where the row-id of input token comes from another buffer. Naive understanding of fused-moe is from token-by-token view as below picture:
|
||||

|
||||
After `moe-sorting`, we can view this algorithm as expert-by-expert, as below:
|
||||

|
||||
|
||||
## optimization
|
||||
summary of the key design of this fused-moe operator:
|
||||
* fuse 2 group-gemm + activation + `topk-weight` multiply into single kernel, using atomic for 2nd gemm accumualation
|
||||
* fuse buffer-zeroing in `moe-sorgin`, user no longer need call extra torch.zero() for the out buffer
|
||||
* fused scatter-gather for row index(same as vllm)
|
||||
* pre-shuffle B matric(weight) to maximize memory throughput. input(activation) keep original layout `[batch, hidden]`.
|
||||
* extrem optimized pipeline using block-inline-asm(we call it `micro-kernel` or `uk`), while not breaking the *composable* design of ck
|
||||
|
||||
##
|
||||
```
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
// each slice map to one expert, and one expert can have multiple slices
|
||||
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
|
||||
//
|
||||
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
|
||||
// * length is (max_num_tokens_padded + block_size - 1) / block_size
|
||||
//
|
||||
// num_tokens_post_padded_ptr : [28]
|
||||
// num_sorted_tiles_ptr : [7]
|
||||
//
|
||||
// * different from vLLM
|
||||
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
|
||||
// 2)need sorted_weight_ptr
|
||||
// 3) use num_sorted_tiles_ptr, already divided by M_a
|
||||
//
|
||||
// * below used for indexing
|
||||
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
|
||||
// 2) sorted_weight_ptr
|
||||
// 3) sorted_expert_ids_ptr
|
||||
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
|
||||
//
|
||||
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
|
||||
```
|
||||
52
example/ck_tile/15_fused_moe/fused_moe.hpp
Normal file
52
example/ck_tile/15_fused_moe/fused_moe.hpp
Normal file
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fused_moesorting.hpp"
|
||||
#include "fused_moegemm.hpp"
|
||||
|
||||
struct fused_moe_args
|
||||
{
|
||||
const void* a_ptr; // [m, k], input token
|
||||
const void* a_scale_ptr; // [m, 1], token scale
|
||||
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
|
||||
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
|
||||
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
|
||||
const void* d_scale_ptr; // [e, 1, k], down scale
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
void* o_ptr; // [m, k], output token (no need to do zeroing)
|
||||
|
||||
const void* topk_ids_ptr; // [tokens, topk]
|
||||
const void* topk_weight_ptr; // [tokens, topk]
|
||||
void* sorted_token_ids_ptr; // [max_num_tokens_padded]
|
||||
void* sorted_weight_ptr; // [max_num_tokens_padded]
|
||||
void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
|
||||
void* num_sorted_tiles_ptr; // [1]
|
||||
|
||||
ck_tile::index_t block_m; // block_m, used to devide the input
|
||||
ck_tile::index_t hidden_size; // k
|
||||
ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
|
||||
ck_tile::index_t num_tokens; // input number of tokens for current iteration
|
||||
ck_tile::index_t num_experts; // number of groups
|
||||
ck_tile::index_t topk; // need this?
|
||||
|
||||
ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
};
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fused_moe_traits
|
||||
{
|
||||
std::string prec_i; // input precision
|
||||
std::string prec_w; // weight precision
|
||||
std::string prec_o; // output precision
|
||||
std::string prec_st; // token scale data type
|
||||
std::string prec_sw; // weight scale data type
|
||||
std::string prec_sq; // smooth quant scale
|
||||
std::string prec_kw; // topk-weight data type
|
||||
int block_m;
|
||||
int gate_only;
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
|
||||
84
example/ck_tile/15_fused_moe/fused_moegemm.hpp
Normal file
84
example/ck_tile/15_fused_moe/fused_moegemm.hpp
Normal file
@@ -0,0 +1,84 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
#include <string>
|
||||
|
||||
// this is only a convenient structure for creating an example
|
||||
// this is not part of the host API
|
||||
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig;
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using GDataType = ck_tile::bf16_t;
|
||||
using DDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::fp16_t;
|
||||
using GDataType = ck_tile::fp16_t;
|
||||
using DDataType = ck_tile::fp16_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::fp16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
template <typename ST, typename SW, typename SQ, typename KW>
|
||||
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
|
||||
{
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using GDataType = ck_tile::int8_t;
|
||||
using DDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
|
||||
using IndexDataType = ck_tile::index_t;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct fused_moegemm_traits
|
||||
{
|
||||
std::string prec_i; // input precision
|
||||
std::string prec_w; // weight precision
|
||||
std::string prec_o; // output precision
|
||||
std::string prec_st; // token scale data type
|
||||
std::string prec_sw; // weight scale data type
|
||||
std::string prec_sq; // smooth quant scale
|
||||
std::string prec_kw; // topk-weight data type
|
||||
int block_m;
|
||||
int gate_only;
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&);
|
||||
20
example/ck_tile/15_fused_moe/fused_moesorting.hpp
Normal file
20
example/ck_tile/15_fused_moe/fused_moesorting.hpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
|
||||
struct fused_moesorting_trait
|
||||
{
|
||||
std::string index_type;
|
||||
std::string weight_type; // currently always float
|
||||
};
|
||||
|
||||
struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s);
|
||||
80
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
Normal file
80
example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp
Normal file
@@ -0,0 +1,80 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fused_moe.hpp"
|
||||
|
||||
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};
|
||||
|
||||
auto o_data_bytes = [&]() {
|
||||
if(t.prec_o == "fp32")
|
||||
return 4;
|
||||
else if(t.prec_o == "fp16" || t.prec_o == "bf16")
|
||||
return 2;
|
||||
else if(t.prec_o == "int8" || t.prec_o == "fp8")
|
||||
return 1;
|
||||
return 1;
|
||||
}();
|
||||
|
||||
auto t0 = fused_moesorting_trait{"int32", "fp32"};
|
||||
auto a0 = fused_moesorting_args{
|
||||
a.topk_ids_ptr, // const void* p_topk_ids;
|
||||
a.topk_weight_ptr, // const void* p_weights;
|
||||
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
|
||||
a.sorted_weight_ptr, // void* p_sorted_weights;
|
||||
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
|
||||
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
|
||||
a.o_ptr, // void* p_moe_buf;
|
||||
a.num_tokens, // index_t tokens;
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
|
||||
};
|
||||
|
||||
auto t1 = fused_moegemm_traits{t.prec_i,
|
||||
t.prec_w,
|
||||
t.prec_o,
|
||||
t.prec_st,
|
||||
t.prec_sw,
|
||||
t.prec_sq,
|
||||
t.prec_kw,
|
||||
t.block_m,
|
||||
t.gate_only,
|
||||
t.fused_quant};
|
||||
auto a1 = fused_moegemm_args{
|
||||
a.a_ptr, // const void* a_ptr;
|
||||
a.a_scale_ptr, // const void* a_scale_ptr;
|
||||
a.g_ptr, // const void* g_ptr;
|
||||
a.d_ptr, // const void* d_ptr;
|
||||
a.g_scale_ptr, // const void* g_scale_ptr;
|
||||
a.d_scale_ptr, // const void* d_scale_ptr;
|
||||
a.y_smooth_scale_ptr, // const void* y_smooth_scale_ptr;
|
||||
a.o_ptr, // void* o_ptr;
|
||||
a.sorted_token_ids_ptr, // const void* sorted_token_ids_ptr;
|
||||
a.sorted_weight_ptr, // const void* sorted_weight_ptr;
|
||||
a.sorted_expert_ids_ptr, // const void* sorted_expert_ids_ptr;
|
||||
a.num_sorted_tiles_ptr, // const void* num_sorted_tiles_ptr;
|
||||
a.hidden_size, // index_t hidden_size;
|
||||
a.intermediate_size, // index_t intermediate_size;
|
||||
a.num_tokens, // index_t num_tokens;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
a.stride_token // index_t stride_token;
|
||||
};
|
||||
|
||||
float r0 = -1;
|
||||
float r1 = -1;
|
||||
|
||||
float r = ck_tile::launch_kernel(
|
||||
s,
|
||||
[=, &r0](const ck_tile::stream_config&) { r0 = fused_moesorting(t0, a0, s_sub); },
|
||||
[=, &r1](const ck_tile::stream_config&) { r1 = fused_moegemm(t1, a1, s_sub); });
|
||||
|
||||
// keep unsupported case return negative
|
||||
if(r0 < 0 || r1 < 0)
|
||||
return -1;
|
||||
|
||||
return r;
|
||||
}
|
||||
33
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
Normal file
33
example/ck_tile/15_fused_moe/instances/fused_moegemm_api.cpp
Normal file
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "fused_moegemm.hpp"
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
|
||||
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
template <typename Traits_>
|
||||
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
template <ck_tile::index_t... Is>
|
||||
using S = ck_tile::sequence<Is...>;
|
||||
|
||||
float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
// clang-format off
|
||||
float r = -1;
|
||||
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
// clang-format on
|
||||
return r;
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
#include "ck_tile/ops/fused_moe.hpp"
|
||||
#include <iostream>
|
||||
|
||||
template <ck_tile::index_t... Is>
|
||||
using S = ck_tile::sequence<Is...>;
|
||||
|
||||
// do not the define of this tepmlate function inside the _api.cpp, otherwise will block make -j
|
||||
template <typename Ts_>
|
||||
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
{
|
||||
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
|
||||
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0,
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
using f_problem =
|
||||
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
|
||||
typename Ts_::GDataType,
|
||||
typename Ts_::DDataType,
|
||||
typename Ts_::AccDataType,
|
||||
typename Ts_::ODataType,
|
||||
typename Ts_::AScaleDataType,
|
||||
typename Ts_::GScaleDataType,
|
||||
typename Ts_::DScaleDataType,
|
||||
typename Ts_::YSmoothScaleDataType,
|
||||
typename Ts_::TopkWeightDataType,
|
||||
typename Ts_::IndexDataType,
|
||||
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
|
||||
f_shape,
|
||||
f_traits>;
|
||||
|
||||
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
|
||||
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
|
||||
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
|
||||
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
|
||||
|
||||
const dim3 grids = f_kernel::GridSize(a);
|
||||
constexpr dim3 blocks = f_kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
static int printed = 0;
|
||||
|
||||
auto kargs = f_kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0 && printed == 0)
|
||||
{
|
||||
std::cout << ", " << f_kernel::GetName() << std::flush;
|
||||
printed = 1;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename I,
|
||||
typename W,
|
||||
typename O,
|
||||
typename ST,
|
||||
typename SW,
|
||||
typename SQ,
|
||||
typename KW,
|
||||
typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
|
||||
typename WarpPerBlock_,
|
||||
typename WarpTile_, // seq<*,*,*>, used to select mfma
|
||||
ck_tile::index_t GateOnly_ = 0,
|
||||
ck_tile::index_t FusedQuant_ = 0>
|
||||
struct fmoe_ // traits, ugly name, only used for internal
|
||||
{
|
||||
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
|
||||
using ADataType = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
|
||||
using GDataType = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
|
||||
using DDataType = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
|
||||
using AccDataType = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
|
||||
using AScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
|
||||
using GScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
|
||||
using DScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
|
||||
using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
|
||||
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
|
||||
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
|
||||
|
||||
static constexpr ck_tile::index_t BT_ = BlockTIle_::at(ck_tile::number<0>{}); // block token
|
||||
static constexpr ck_tile::index_t BI_ =
|
||||
BlockTIle_::at(ck_tile::number<1>{}); // block intermediate
|
||||
static constexpr ck_tile::index_t BH_ = BlockTIle_::at(ck_tile::number<2>{}); // block hidden
|
||||
static constexpr ck_tile::index_t BD_ = BlockTIle_::at(ck_tile::number<3>{}); // block down
|
||||
|
||||
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
|
||||
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
|
||||
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr ck_tile::index_t GateOnly = GateOnly_;
|
||||
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
|
||||
};
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "fused_moegemm.hpp"
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
#include "fused_moegemm_api_internal.hpp"
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "fused_moegemm.hpp"
|
||||
#include "fused_moegemm_api_traits.hpp"
|
||||
#include "fused_moegemm_api_internal.hpp"
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "fused_moesorting.hpp"
|
||||
|
||||
#define MOE_SORTING_DISPATCH(unroll_num_) \
|
||||
constexpr ck_tile::index_t unroll_num = unroll_num_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
const auto lds_bytes = kernel::GetSmemSize(a); \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
|
||||
return ave_time;
|
||||
|
||||
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.weight_type == "fp32" && t.index_type == "int32")
|
||||
{
|
||||
if(a.num_experts > 127)
|
||||
{
|
||||
printf("lds size exceed, only support experts <127 \n");
|
||||
return -1;
|
||||
}
|
||||
if(a.moe_buf_bytes % 16)
|
||||
{
|
||||
printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
|
||||
return -1;
|
||||
}
|
||||
using index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
|
||||
switch(smem_io_unroll_num)
|
||||
{
|
||||
case(1): {
|
||||
MOE_SORTING_DISPATCH(1);
|
||||
}
|
||||
case(2): {
|
||||
MOE_SORTING_DISPATCH(2);
|
||||
}
|
||||
case(3): {
|
||||
MOE_SORTING_DISPATCH(3);
|
||||
}
|
||||
case(5): {
|
||||
MOE_SORTING_DISPATCH(5);
|
||||
}
|
||||
case(6): {
|
||||
MOE_SORTING_DISPATCH(6);
|
||||
}
|
||||
case(7): {
|
||||
MOE_SORTING_DISPATCH(7);
|
||||
}
|
||||
case(8): {
|
||||
MOE_SORTING_DISPATCH(8);
|
||||
}
|
||||
case(9): {
|
||||
MOE_SORTING_DISPATCH(9);
|
||||
}
|
||||
case(10): {
|
||||
MOE_SORTING_DISPATCH(10);
|
||||
}
|
||||
case(11): {
|
||||
MOE_SORTING_DISPATCH(11);
|
||||
}
|
||||
default: {
|
||||
MOE_SORTING_DISPATCH(4);
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
603
example/ck_tile/15_fused_moe/main.cpp
Normal file
603
example/ck_tile/15_fused_moe/main.cpp
Normal file
@@ -0,0 +1,603 @@
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "fused_moe.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
// mfma_type, 0:32x32, 1:16x16
|
||||
// TODO: padding?
|
||||
template <typename T>
|
||||
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
|
||||
{
|
||||
assert(t.get_lengths().size() == 3);
|
||||
int b_ = t.get_lengths()[0];
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[2];
|
||||
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 32, 4, 8});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
|
||||
{
|
||||
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename IndexType>
|
||||
void topid_unique_gen(
|
||||
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
|
||||
{
|
||||
size_t total_size = topk * tokens;
|
||||
std::srand(seed);
|
||||
std::set<IndexType> unique_set;
|
||||
IndexType current_v;
|
||||
for(size_t i = 0; i < total_size; i++)
|
||||
{
|
||||
if(i % topk == 0)
|
||||
{
|
||||
unique_set.clear();
|
||||
}
|
||||
current_v = std::rand() % num_expert;
|
||||
while(unique_set.find(current_v) != unique_set.end())
|
||||
{
|
||||
current_v = std::rand() % num_expert;
|
||||
}
|
||||
unique_set.insert(current_v);
|
||||
host_tensor[i] = current_v;
|
||||
}
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("t", "128", "num input tokens")
|
||||
.insert("e", "32", "num of experts")
|
||||
.insert("k", "5", "topk")
|
||||
.insert("h", "8192", "hidden_size of this model")
|
||||
.insert("i", "8192", "intermediate_size between 2 gemms of FFN")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
|
||||
.insert("bm", "32", "blocking factor for sorted tokens")
|
||||
.insert("tp", "8", "tensor parallel size")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec_i", "bf16", "input precision")
|
||||
.insert("prec_w", "bf16", "weight precision")
|
||||
.insert("prec_o", "bf16", "output precision")
|
||||
.insert("prec_st", "auto", "token scale data type. auto will set to fp32")
|
||||
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
|
||||
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
|
||||
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert(
|
||||
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
|
||||
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
|
||||
.insert("balance",
|
||||
"0",
|
||||
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
|
||||
.insert("init",
|
||||
"2",
|
||||
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
|
||||
"normalized(slow)")
|
||||
.insert("seed", "11939", "seed used to do random")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type,
|
||||
// SQ:smooth-quant-type, KW:topk-weight-type
|
||||
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t tokens = arg_parser.get_int("t");
|
||||
ck_tile::index_t experts = arg_parser.get_int("e");
|
||||
ck_tile::index_t topk = arg_parser.get_int("k");
|
||||
ck_tile::index_t hidden_size = arg_parser.get_int("h");
|
||||
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
ck_tile::index_t block_m = arg_parser.get_int("bm");
|
||||
if(stride < 0)
|
||||
stride = hidden_size;
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_w = arg_parser.get_str("prec_w");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_st = arg_parser.get_str("prec_st");
|
||||
std::string prec_sw = arg_parser.get_str("prec_sw");
|
||||
std::string prec_sq = arg_parser.get_str("prec_sq");
|
||||
std::string prec_kw = arg_parser.get_str("prec_kw");
|
||||
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
|
||||
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
|
||||
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
|
||||
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int gate_only = arg_parser.get_int("gate_only");
|
||||
int api = arg_parser.get_int("api");
|
||||
int balance = arg_parser.get_int("balance");
|
||||
int tp = arg_parser.get_int("tp");
|
||||
int init = arg_parser.get_int("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
|
||||
// w0 (Gate+Up or Gate only, N size)
|
||||
ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
|
||||
// w1 (Down, N size)
|
||||
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
|
||||
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
if(prec_i != prec_w)
|
||||
base_str += "x" + prec_w;
|
||||
if(prec_i != prec_o)
|
||||
base_str += "=" + prec_o;
|
||||
if(fused_quant != 0)
|
||||
{
|
||||
base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
auto api_str = [&]() {
|
||||
if(api == 0)
|
||||
return std::string("fmoe");
|
||||
else if(api == 1)
|
||||
return std::string("moeg");
|
||||
else if(api == 2)
|
||||
return std::string("moes");
|
||||
return std::string("");
|
||||
}();
|
||||
|
||||
auto stride_str = [&]() {
|
||||
if(stride == hidden_size)
|
||||
return std::string("");
|
||||
else
|
||||
return std::string(", st:") + std::to_string(stride);
|
||||
}();
|
||||
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
|
||||
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
|
||||
<< ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
|
||||
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
|
||||
|
||||
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using GDataType = typename TypeConfig::GDataType;
|
||||
using DDataType = typename TypeConfig::DDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using ODataType = typename TypeConfig::ODataType;
|
||||
using AScaleDataType = typename TypeConfig::AScaleDataType;
|
||||
using GScaleDataType = typename TypeConfig::GScaleDataType;
|
||||
using DScaleDataType = typename TypeConfig::DScaleDataType;
|
||||
using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
|
||||
using TopkWeightDataType = typename TypeConfig::TopkWeightDataType;
|
||||
using IndexDataType = typename TypeConfig::IndexDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
|
||||
ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
|
||||
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
|
||||
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
|
||||
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
|
||||
ck_tile::HostTensor<DScaleDataType> sd_host({shared_intermediate_size_1});
|
||||
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
|
||||
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sort
|
||||
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sort
|
||||
|
||||
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
|
||||
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
|
||||
{(max_num_tokens_padded + block_m - 1) / block_m});
|
||||
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
|
||||
|
||||
if(init == 0)
|
||||
{
|
||||
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
|
||||
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
|
||||
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
|
||||
ck_tile::FillStepRange<AScaleDataType>{0.f, 1.f, 0.01f}(sa_host);
|
||||
ck_tile::FillStepRange<GScaleDataType>{0.f, 1.f, 0.01f}(sg_host);
|
||||
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
|
||||
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
|
||||
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
|
||||
}
|
||||
else if(init == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
|
||||
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host);
|
||||
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host);
|
||||
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
|
||||
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
|
||||
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
|
||||
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
|
||||
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed, true}(
|
||||
topk_weight_host);
|
||||
}
|
||||
else if(init == 2)
|
||||
{
|
||||
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed, true}(a_host);
|
||||
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed, true}(g_host);
|
||||
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
|
||||
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed, true}(sa_host);
|
||||
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed, true}(sg_host);
|
||||
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
|
||||
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
|
||||
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
|
||||
}
|
||||
|
||||
// permute weight
|
||||
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
|
||||
// do moe sorting
|
||||
if(balance)
|
||||
{
|
||||
int e_cnt = 0;
|
||||
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
|
||||
{
|
||||
topk_ids_host.mData[i] = e_cnt;
|
||||
e_cnt++;
|
||||
if(e_cnt >= experts)
|
||||
e_cnt = 0;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
|
||||
}
|
||||
|
||||
// leave it here for future debug purpose
|
||||
#if 0
|
||||
a_host.loadtxt("../../ater/input_torch.txt");
|
||||
|
||||
topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int");
|
||||
// topk_ids_host.savetxt("topk_ids_2.txt");
|
||||
topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float");
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
|
||||
g_host.loadtxt("../../ater/w1_torch.txt", "float");
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
d_host.loadtxt("../../ater/w2_torch.txt", "float");
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
|
||||
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
std::cout << "sorted_token_ids_host:" << sorted_token_ids_host << std::endl;
|
||||
std::cout << "num_sorted_tiles_host:" << num_sorted_tiles_host << std::endl;
|
||||
std::cout << "sorted_expert_ids_host:" << sorted_expert_ids_host << std::endl;
|
||||
std::cout << "topk_weight_host:" << topk_weight_host << std::endl;
|
||||
std::cout << "sorted_weight_host:" << sorted_weight_host << std::endl;
|
||||
#endif
|
||||
auto cal_tflops = [&](auto ms) {
|
||||
double flop_gemm_0 =
|
||||
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_0 * hidden_size;
|
||||
double flop_gemm_1 =
|
||||
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_1 * hidden_size;
|
||||
return (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ms) * 1e-3) / 1e12;
|
||||
};
|
||||
|
||||
// TODO: this method we use expert-by-expert view, just for reference
|
||||
auto cal_tbps = [&](auto ms) {
|
||||
double token_bytes =
|
||||
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ADataType);
|
||||
double w0_bytes = static_cast<double>(shared_intermediate_size_0) * experts * hidden_size *
|
||||
sizeof(GDataType);
|
||||
double w1_bytes = static_cast<double>(shared_intermediate_size_1) * experts * hidden_size *
|
||||
sizeof(DDataType);
|
||||
double o_bytes =
|
||||
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ODataType);
|
||||
double topk_weights_bytes = static_cast<double>(tokens) * topk * sizeof(TopkWeightDataType);
|
||||
// ignore index, they are too small
|
||||
|
||||
return (token_bytes + w0_bytes + w1_bytes + o_bytes + topk_weights_bytes) /
|
||||
(static_cast<double>(ms) * 1e-3) / 1e12;
|
||||
};
|
||||
|
||||
if(api == 0)
|
||||
{
|
||||
ck_tile::DeviceMem a_buf(a_host);
|
||||
ck_tile::DeviceMem g_perm_buf(g_perm_host);
|
||||
ck_tile::DeviceMem d_perm_buf(d_perm_host);
|
||||
ck_tile::DeviceMem sa_buf(sa_host);
|
||||
ck_tile::DeviceMem sg_buf(sg_host);
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem topk_ids_buf(topk_ids_host);
|
||||
ck_tile::DeviceMem topk_weight_buf(topk_weight_host);
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_buf(
|
||||
sorted_token_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sorted_expert_ids_buf(
|
||||
sorted_expert_ids_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(
|
||||
num_sorted_tiles_host.get_element_space_size_in_bytes());
|
||||
|
||||
fused_moe_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
prec_st,
|
||||
prec_sw,
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
fused_moe_args args{a_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
|
||||
g_perm_buf.GetDeviceBuffer(),
|
||||
d_perm_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
topk_ids_buf.GetDeviceBuffer(),
|
||||
topk_weight_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
sorted_weight_buf.GetDeviceBuffer(),
|
||||
sorted_expert_ids_buf.GetDeviceBuffer(),
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
block_m,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
float ave_time = fused_moe(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << " not supported!" << std::endl << std::flush;
|
||||
return false;
|
||||
}
|
||||
|
||||
// float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
|
||||
<< cal_tbps(ave_time) << " TB/s" << std::flush;
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
topk_ids_host,
|
||||
topk_weight_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m);
|
||||
|
||||
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
|
||||
a_host,
|
||||
g_host,
|
||||
d_host,
|
||||
sa_host,
|
||||
sg_host,
|
||||
sd_host,
|
||||
sy_host,
|
||||
o_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host,
|
||||
topk_ids_host,
|
||||
block_m,
|
||||
tokens,
|
||||
experts,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
topk,
|
||||
gate_only);
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
auto [rtol, atol] = get_elimit<ADataType>();
|
||||
pass &= ck_tile::check_err(
|
||||
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
std::cout << std::flush << std::endl;
|
||||
return pass;
|
||||
}
|
||||
else if(api == 1)
|
||||
{
|
||||
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
|
||||
topk_ids_host,
|
||||
topk_weight_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host.mData[0],
|
||||
experts,
|
||||
block_m);
|
||||
|
||||
// done, preparing GPU buffer
|
||||
ck_tile::DeviceMem a_buf(a_host);
|
||||
ck_tile::DeviceMem g_perm_buf(g_perm_host);
|
||||
ck_tile::DeviceMem d_perm_buf(d_perm_host);
|
||||
ck_tile::DeviceMem sa_buf(sa_host);
|
||||
ck_tile::DeviceMem sg_buf(sg_host);
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
ck_tile::DeviceMem sy_buf(sy_host);
|
||||
ck_tile::DeviceMem o_buf(o_host);
|
||||
|
||||
// manually clear output buffer for atomic
|
||||
o_buf.SetZero();
|
||||
//
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
|
||||
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
|
||||
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
|
||||
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
|
||||
|
||||
fused_moegemm_traits traits{prec_i,
|
||||
prec_w,
|
||||
prec_o,
|
||||
prec_st,
|
||||
prec_sw,
|
||||
prec_sq,
|
||||
prec_kw,
|
||||
block_m,
|
||||
gate_only,
|
||||
fused_quant};
|
||||
|
||||
fused_moegemm_args args{a_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
|
||||
g_perm_buf.GetDeviceBuffer(),
|
||||
d_perm_buf.GetDeviceBuffer(),
|
||||
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
|
||||
o_buf.GetDeviceBuffer(),
|
||||
sorted_token_ids_buf.GetDeviceBuffer(),
|
||||
sorted_weight_buf.GetDeviceBuffer(),
|
||||
sorted_expert_ids_buf.GetDeviceBuffer(),
|
||||
num_sorted_tiles_buf.GetDeviceBuffer(),
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
|
||||
float ave_time = fused_moegemm(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << " not supported!" << std::endl << std::flush;
|
||||
return false;
|
||||
}
|
||||
|
||||
// float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
|
||||
<< cal_tbps(ave_time) << " TB/s" << std::flush;
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
|
||||
a_host,
|
||||
g_host,
|
||||
d_host,
|
||||
sa_host,
|
||||
sg_host,
|
||||
sd_host,
|
||||
sy_host,
|
||||
o_host,
|
||||
sorted_token_ids_host,
|
||||
sorted_weight_host,
|
||||
sorted_expert_ids_host,
|
||||
num_sorted_tiles_host,
|
||||
topk_ids_host,
|
||||
block_m,
|
||||
tokens,
|
||||
experts,
|
||||
hidden_size,
|
||||
shared_intermediate_size_0,
|
||||
topk,
|
||||
gate_only);
|
||||
|
||||
auto o_dev = o_buf.ToHost<ODataType>();
|
||||
// o_dev.savetxt("gpu-out.txt", "float");
|
||||
auto [rtol, atol] = get_elimit<ADataType>();
|
||||
pass &= ck_tile::check_err(
|
||||
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
|
||||
}
|
||||
std::cout << std::flush << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_w = arg_parser.get_str("prec_w");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_st = arg_parser.get_str("prec_st");
|
||||
std::string prec_sw = arg_parser.get_str("prec_sw");
|
||||
std::string prec_sq = arg_parser.get_str("prec_sq");
|
||||
std::string prec_kw = arg_parser.get_str("prec_kw");
|
||||
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
|
||||
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
|
||||
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
|
||||
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
|
||||
|
||||
// no dynamic quant case
|
||||
if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
|
||||
arg_parser)
|
||||
? 0
|
||||
: -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32")
|
||||
{
|
||||
return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>(
|
||||
arg_parser)
|
||||
? 0
|
||||
: -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
BIN
example/ck_tile/15_fused_moe/misc/moe-0.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-0.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 75 KiB |
BIN
example/ck_tile/15_fused_moe/misc/moe-1.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 90 KiB |
BIN
example/ck_tile/15_fused_moe/misc/moe-2.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-2.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 124 KiB |
BIN
example/ck_tile/15_fused_moe/misc/moe-3.png
Normal file
BIN
example/ck_tile/15_fused_moe/misc/moe-3.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 18 KiB |
@@ -14,3 +14,5 @@ add_subdirectory(11_add_rmsnorm2d_rdquant)
|
||||
add_subdirectory(12_smoothquant)
|
||||
add_subdirectory(13_moe_sorting)
|
||||
add_subdirectory(14_moe_smoothquant)
|
||||
add_subdirectory(15_fused_moe)
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_linear.hpp"
|
||||
#include "ck_tile/core/tensor/tile_window_utils.hpp"
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
@@ -62,6 +63,7 @@
|
||||
#include "ck_tile/core/utility/philox_rand.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator.hpp"
|
||||
#include "ck_tile/core/utility/static_counter.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
@@ -621,6 +621,65 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void lds_load_fence(index_t cnt = 0)
|
||||
{
|
||||
asm volatile("s_waitcnt lgkmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
template <typename scalar_type, index_t N, bool pre_nop = false>
|
||||
struct buffer_atomic_add_if;
|
||||
|
||||
template <bool pre_nop>
|
||||
struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
:
|
||||
: "v"(v_offset),
|
||||
"v"(bit_cast<mbuf_t>(value)),
|
||||
"s"(res.xy),
|
||||
"n"(i_offset),
|
||||
"v"(flag),
|
||||
"s"(save_exec)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_type, index_t N, bool pre_nop = false>
|
||||
struct buffer_atomic_add;
|
||||
|
||||
template <bool pre_nop>
|
||||
struct buffer_atomic_add<bf16_t, 2, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag = 1*/)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
|
||||
:
|
||||
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
// below type indicate the data type used for buffer load inline asm
|
||||
// clang-format off
|
||||
@@ -810,6 +869,11 @@ CK_TILE_DEVICE void buffer_store_fence(index_t cnt = 0)
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto async_load_fence_raw(index_t cnt = 0)
|
||||
{
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
// buffer load i8
|
||||
CK_TILE_DEVICE_EXTERN int8_t
|
||||
llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
|
||||
@@ -2378,6 +2442,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
|
||||
T* p_dst_wave,
|
||||
const index_t dst_thread_element_offset,
|
||||
const index_t dst_linear_element_offset,
|
||||
const bool dst_thread_element_valid,
|
||||
const index_t dst_element_space_size,
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
const int32x4_t dst_wave_buffer_resource =
|
||||
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
|
||||
|
||||
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
|
||||
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
|
||||
|
||||
if constexpr(oob_conditional_check)
|
||||
{
|
||||
buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
0,
|
||||
dst_linear_addr_offset,
|
||||
dst_thread_element_valid);
|
||||
}
|
||||
else
|
||||
{
|
||||
buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
0,
|
||||
dst_linear_addr_offset,
|
||||
1);
|
||||
}
|
||||
}
|
||||
|
||||
// buffer_atomic_max requires:
|
||||
// 1) p_dst_wave must point to global memory
|
||||
// 2) p_dst_wave must be a wavewise pointer.
|
||||
|
||||
@@ -73,6 +73,24 @@ CK_TILE_DEVICE void block_sync_lds()
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
asm volatile("s_wait_loadcnt %0 \n"
|
||||
"s_barrier_signal -1 \n"
|
||||
"s_barrier_wait -1"
|
||||
:
|
||||
: "n"(cnt)
|
||||
: "memory");
|
||||
#else
|
||||
asm volatile("s_waitcnt vmcnt(%0) \n"
|
||||
"s_barrier"
|
||||
:
|
||||
: "n"(cnt)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void block_sync_lds_direct_load()
|
||||
{
|
||||
asm volatile("\
|
||||
|
||||
@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
// per-thread v_flag store into 2x sgpr
|
||||
uint32x2_t exec_flag;
|
||||
asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
|
||||
: [s_exec_flag] "=s"(exec_flag)
|
||||
: [v_flag] "v"(v_flag));
|
||||
return exec_flag;
|
||||
}
|
||||
|
||||
template <typename X, typename Y>
|
||||
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
|
||||
{
|
||||
static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
|
||||
// per-thread cmp store into 2x sgpr
|
||||
uint32x2_t exec_flag;
|
||||
asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
|
||||
: [s_exec_flag] "=s"(exec_flag)
|
||||
: [v_x] "v"(x), [v_y] "v"(y));
|
||||
return exec_flag;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
CK_TILE_DEVICE void update(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x);
|
||||
this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_add)
|
||||
{
|
||||
this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
|
||||
this->template atomic_add<X, oob_conditional_check>(
|
||||
i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_max)
|
||||
{
|
||||
this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
|
||||
this->template atomic_max<X, oob_conditional_check>(
|
||||
i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
// FIXME: remove memory_operation_enum::add
|
||||
else if constexpr(Op == memory_operation_enum::add)
|
||||
{
|
||||
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
|
||||
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
|
||||
auto tmp =
|
||||
this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
|
||||
this->template set<X, oob_conditional_check>(
|
||||
i, linear_offset, is_valid_element, x + tmp);
|
||||
// tmp += x;
|
||||
// this->template set<X>(i, is_valid_element, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <memory_operation_enum Op,
|
||||
typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void update_raw(index_t i,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
if constexpr(Op == memory_operation_enum::set)
|
||||
{
|
||||
this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_add)
|
||||
{
|
||||
this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
|
||||
i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
else if constexpr(Op == memory_operation_enum::atomic_max)
|
||||
{
|
||||
// this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
|
||||
}
|
||||
}
|
||||
|
||||
// i is offset of T, not X. i should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
@@ -585,6 +626,39 @@ struct buffer_view<address_space_enum::global,
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
bool>::type = false>
|
||||
CK_TILE_DEVICE void
|
||||
atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
|
||||
{
|
||||
// using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
|
||||
|
||||
// X contains multiple T
|
||||
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
|
||||
|
||||
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
|
||||
|
||||
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
|
||||
"wrong! X should contain multiple T");
|
||||
|
||||
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
|
||||
|
||||
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
|
||||
|
||||
amd_buffer_atomic_add_raw<remove_cvref_t<T>,
|
||||
t_per_x,
|
||||
Coherence,
|
||||
oob_conditional_check,
|
||||
pre_nop>(
|
||||
x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
typename std::enable_if<
|
||||
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
|
||||
|
||||
@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
@@ -51,15 +55,35 @@ template <typename DistributedTensor_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
|
||||
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -76,6 +100,7 @@ template <typename T,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
@@ -83,11 +108,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(
|
||||
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
@@ -95,6 +121,7 @@ template <typename T,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
@@ -102,11 +129,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.load_raw(
|
||||
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
@@ -114,6 +142,7 @@ template <typename LdsTileWindow_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto
|
||||
@@ -122,11 +151,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(
|
||||
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
return tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
@@ -134,6 +166,7 @@ template <typename LdsTileWindow_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
@@ -141,11 +174,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(
|
||||
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
|
||||
return tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
|
||||
|
||||
@@ -201,4 +201,30 @@ CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number
|
||||
return unpacks;
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
// check if 2 static_distributed_tensor has same data type and size of element
|
||||
// but only difference in distribution
|
||||
template <typename X, typename Y>
|
||||
struct is_similiar_distributed_tensor
|
||||
{
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <typename TypeX, typename DistX, typename TypeY, typename DistY>
|
||||
struct is_similiar_distributed_tensor<static_distributed_tensor<TypeX, DistX>,
|
||||
static_distributed_tensor<TypeY, DistY>>
|
||||
{
|
||||
using Tx = static_distributed_tensor<TypeX, DistX>;
|
||||
using Ty = static_distributed_tensor<TypeY, DistY>;
|
||||
static constexpr bool value = std::is_same_v<typename Tx::DataType, typename Ty::DataType> &&
|
||||
Tx::get_thread_buffer_size() == Ty::get_thread_buffer_size();
|
||||
};
|
||||
|
||||
template <typename X, typename Y>
|
||||
inline constexpr bool is_similiar_distributed_tensor_v =
|
||||
is_similiar_distributed_tensor<X, Y>::value;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -333,6 +333,48 @@ struct tensor_view
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
// X is vector of DataType.
|
||||
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
|
||||
coord.get_offset(),
|
||||
linear_offset,
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
x);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false,
|
||||
typename std::enable_if<
|
||||
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
|
||||
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
|
||||
bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr void
|
||||
update_vectorized_elements_raw(const TensorCoord& coord,
|
||||
index_t linear_offset,
|
||||
bool is_valid_element,
|
||||
const X& x,
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
|
||||
coord.get_offset(), linear_offset, is_valid_element, x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE void print() const
|
||||
{
|
||||
printf("tensor_view{");
|
||||
|
||||
@@ -292,12 +292,15 @@ struct tile_window_with_static_distribution
|
||||
{
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
load(dst_tensor, bool_constant<oob_conditional_check>{});
|
||||
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DistributedTensor, bool oob_conditional_check = true>
|
||||
template <typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
@@ -785,6 +788,73 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
|
||||
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
0,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// move thread's botom tensor coordiante
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
|
||||
@@ -432,23 +432,38 @@ struct tile_window_linear
|
||||
CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
|
||||
{
|
||||
constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
|
||||
// since this is linear offset, we assum bottom X tensor is always linear
|
||||
constexpr index_t linear_offset = [&]() {
|
||||
constexpr auto x_idx_ = linear_coord;
|
||||
constexpr auto x_len_ = TileDstr{}.get_lengths();
|
||||
static_assert(x_idx_.size() == x_len_.size());
|
||||
constexpr index_t x_dims_ = x_idx_.size();
|
||||
index_t cu_stride_ = 1;
|
||||
index_t cu_offset_ = 0;
|
||||
static_for<0, x_dims_, 1>{}([&](auto i_) {
|
||||
auto r_i_ = number<x_dims_ - i_ - 1>{};
|
||||
cu_offset_ += x_idx_[r_i_] * cu_stride_;
|
||||
cu_stride_ *= x_len_[r_i_];
|
||||
});
|
||||
return cu_offset_;
|
||||
}();
|
||||
|
||||
return linear_offset;
|
||||
constexpr auto is_pure_linear_tensor =
|
||||
reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
|
||||
if constexpr(is_pure_linear_tensor)
|
||||
{
|
||||
// this case usually is a LDS window, everything is known at compile tile.
|
||||
// we directly use BottomTensorView transform to compute the offset, in case padding
|
||||
auto bottom_tensor_coord =
|
||||
make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
|
||||
return bottom_tensor_coord.get_offset();
|
||||
}
|
||||
else
|
||||
{
|
||||
// this case usually is a global window, where last dim can be linear
|
||||
// we hack here, that use the original TileDstr to compute the linear offset
|
||||
// ... hoping that there is no extra padding between other dims, which make sense
|
||||
// since that would introduce runtime length (so can't use linear offset)
|
||||
constexpr index_t linear_offset = [&]() {
|
||||
constexpr auto x_idx_ = linear_coord;
|
||||
constexpr auto x_len_ = TileDstr{}.get_lengths();
|
||||
static_assert(x_idx_.size() == x_len_.size());
|
||||
constexpr index_t x_dims_ = x_idx_.size();
|
||||
index_t cu_stride_ = 1;
|
||||
index_t cu_offset_ = 0;
|
||||
static_for<0, x_dims_, 1>{}([&](auto i_) {
|
||||
auto r_i_ = number<x_dims_ - i_ - 1>{};
|
||||
cu_offset_ += x_idx_[r_i_] * cu_stride_;
|
||||
cu_stride_ *= x_len_[r_i_];
|
||||
});
|
||||
return cu_offset_;
|
||||
}();
|
||||
return linear_offset;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
|
||||
@@ -509,6 +524,64 @@ struct tile_window_linear
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(DstTile& dst_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
|
||||
// write into distributed tensor
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
static_assert(d % traits::ScalarPerVector == 0);
|
||||
|
||||
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
|
||||
number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
|
||||
#endif
|
||||
};
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DstTile,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
@@ -849,6 +922,58 @@ struct tile_window_linear
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
|
||||
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {}) const
|
||||
{
|
||||
|
||||
using vector_t = typename traits::vector_t;
|
||||
using SFC_Ys = typename traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
auto issue = [&](auto i_access_) {
|
||||
constexpr auto IAccess = number<i_access_>{};
|
||||
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
|
||||
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
|
||||
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
|
||||
auto bottom_tensor_flag = cached_flags_[IAccess];
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
|
||||
|
||||
// read from distributed tensor
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
|
||||
|
||||
vec_value.template get_as<DataType>()(j) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
linear_offset,
|
||||
bottom_tensor_flag,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
};
|
||||
|
||||
WINDOW_DISPATCH_ISSUE();
|
||||
}
|
||||
|
||||
// move thread's botom tensor coordiante
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
|
||||
54
include/ck_tile/core/tensor/tile_window_utils.hpp
Normal file
54
include/ck_tile/core/tensor/tile_window_utils.hpp
Normal file
@@ -0,0 +1,54 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
#pragma once
|
||||
namespace ck_tile {
|
||||
|
||||
// input a lds store tile, extract some information from it
|
||||
// used to set m0 value for gfx9 serious
|
||||
template <typename LdsTileWindow_>
|
||||
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
|
||||
// issues * warps * lanes
|
||||
static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
|
||||
|
||||
const index_t size_per_buf =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType);
|
||||
|
||||
const index_t size_per_wave =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<0>{}, number<1>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t size_per_issue =
|
||||
lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(number<1>{}, number<0>{}, number<0>{})) *
|
||||
sizeof(LdsDataType) -
|
||||
size_per_buf;
|
||||
|
||||
const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
|
||||
|
||||
return make_tuple(m0_init_value, size_per_issue);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -41,15 +41,65 @@ template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_>
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
tile_window.update(dstr_tensor);
|
||||
tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE void
|
||||
update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.update_raw(dstr_tensor,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
typename DataType_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto update_tile_raw(
|
||||
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
|
||||
tile_window,
|
||||
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
tile_window.update_raw(dstr_tensor,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
116
include/ck_tile/core/utility/static_counter.hpp
Normal file
116
include/ck_tile/core/utility/static_counter.hpp
Normal file
@@ -0,0 +1,116 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Context, index_t Start = 0, index_t Step = 1>
|
||||
struct static_counter
|
||||
{
|
||||
public:
|
||||
template <typename Unique>
|
||||
static constexpr index_t next()
|
||||
{
|
||||
return next<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <unsigned long long>
|
||||
static constexpr index_t next()
|
||||
{
|
||||
struct Unique
|
||||
{
|
||||
};
|
||||
return next<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <typename Unique>
|
||||
static constexpr index_t current()
|
||||
{
|
||||
return current<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
template <unsigned long long>
|
||||
static constexpr index_t current()
|
||||
{
|
||||
struct Unique
|
||||
{
|
||||
};
|
||||
return current<Unique>(0) * Step + Start;
|
||||
}
|
||||
|
||||
private:
|
||||
template <index_t I>
|
||||
struct slot
|
||||
{
|
||||
_Pragma("GCC diagnostic push");
|
||||
_Pragma("GCC diagnostic ignored \"-Wundefined-internal\"");
|
||||
friend constexpr bool slot_allocated(slot<I>);
|
||||
_Pragma("GCC diagnostic pop");
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
struct allocate_slot
|
||||
{
|
||||
friend constexpr bool slot_allocated(slot<I>) { return true; }
|
||||
enum
|
||||
{
|
||||
value = I
|
||||
};
|
||||
};
|
||||
|
||||
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
|
||||
// the overload set...
|
||||
template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
|
||||
static constexpr index_t next(index_t)
|
||||
{
|
||||
return next<Unique, I + 1>(0);
|
||||
}
|
||||
|
||||
// ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
|
||||
// allocate_slot<I>.
|
||||
template <typename Unique, index_t I = 0>
|
||||
static constexpr index_t next(double)
|
||||
{
|
||||
return allocate_slot<I>::value;
|
||||
}
|
||||
|
||||
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
|
||||
// the overload set...
|
||||
template <typename Unique, index_t I = Start, bool = slot_allocated(slot<I>())>
|
||||
static constexpr index_t current(index_t)
|
||||
{
|
||||
return current<Unique, I + 1>(0);
|
||||
}
|
||||
|
||||
// ...And this function will be used, instead, which will return the current counter, or assert
|
||||
// in case next() hasn't been called yet.
|
||||
template <typename Unique, index_t I = Start>
|
||||
static constexpr index_t current(double)
|
||||
{
|
||||
static_assert(I != 0, "You must invoke next() first");
|
||||
|
||||
return I - 1;
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
template <int I>
|
||||
struct static_counter_uniq_;
|
||||
}
|
||||
|
||||
#define MAKE_SC() \
|
||||
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
|
||||
#define MAKE_SC_WITH(start_, step_) \
|
||||
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}
|
||||
#define NEXT_SC(c_) c_.next<__COUNTER__>()
|
||||
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()
|
||||
|
||||
// Usage:
|
||||
// constexpr auto c = MAKE_SC()
|
||||
// NEXT_SC(c) // -> constexpr 0
|
||||
// NEXT_SC(c) // -> constexpr 1
|
||||
// NEXT_SC(c) // -> constexpr 2
|
||||
} // namespace ck_tile
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/host/fill.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/joinable_thread.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/ranges.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
|
||||
@@ -20,6 +21,7 @@
|
||||
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
|
||||
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_fused_moe.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/host/reference/reference_im2col.hpp"
|
||||
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <stdint.h>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename T>
|
||||
@@ -36,6 +37,19 @@ struct DeviceMem
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
|
||||
{
|
||||
if(mMemSize != 0)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
else
|
||||
{
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
ToDevice(t.data());
|
||||
}
|
||||
void Realloc(std::size_t mem_size)
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
@@ -92,6 +106,27 @@ struct DeviceMem
|
||||
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
}
|
||||
|
||||
// construct a host tensor with type T
|
||||
template <typename T>
|
||||
HostTensor<T> ToHost(std::size_t cpySize)
|
||||
{
|
||||
// TODO: host tensor could be slightly larger than the device tensor
|
||||
// we just copy all data from GPU buffer
|
||||
std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
|
||||
HostTensor<T> h_({host_elements});
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
return h_;
|
||||
}
|
||||
template <typename T>
|
||||
HostTensor<T> ToHost()
|
||||
{
|
||||
return ToHost<T>(mMemSize);
|
||||
}
|
||||
|
||||
void SetZero() const
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/joinable_thread.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -22,13 +23,44 @@ struct FillUniformDistribution
|
||||
float a_{-5.f};
|
||||
float b_{5.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
// ATTENTION: threaded does not guarantee the distribution between thread
|
||||
bool threaded = false;
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
|
||||
if(threaded)
|
||||
{
|
||||
uint32_t num_thread = std::thread::hardware_concurrency();
|
||||
auto total = static_cast<std::size_t>(std::distance(first, last));
|
||||
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
|
||||
|
||||
std::vector<joinable_thread> threads(num_thread);
|
||||
for(std::size_t it = 0; it < num_thread; ++it)
|
||||
{
|
||||
std::size_t iw_begin = it * work_per_thread;
|
||||
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
|
||||
auto thread_f = [this, total, iw_begin, iw_end, &first] {
|
||||
if(iw_begin > total || iw_end > total)
|
||||
return;
|
||||
// need to make each thread unique, add an offset to current seed
|
||||
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
|
||||
: std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
|
||||
return ck_tile::type_convert<T>(dis(gen));
|
||||
});
|
||||
};
|
||||
threads[it] = joinable_thread(thread_f);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
std::generate(
|
||||
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
@@ -115,13 +147,44 @@ struct FillNormalDistribution
|
||||
float mean_{0.f};
|
||||
float variance_{1.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
// ATTENTION: threaded does not guarantee the distribution between thread
|
||||
bool threaded = false;
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
|
||||
std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
|
||||
if(threaded)
|
||||
{
|
||||
uint32_t num_thread = std::thread::hardware_concurrency();
|
||||
auto total = static_cast<std::size_t>(std::distance(first, last));
|
||||
auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
|
||||
|
||||
std::vector<joinable_thread> threads(num_thread);
|
||||
for(std::size_t it = 0; it < num_thread; ++it)
|
||||
{
|
||||
std::size_t iw_begin = it * work_per_thread;
|
||||
std::size_t iw_end = std::min((it + 1) * work_per_thread, total);
|
||||
auto thread_f = [this, total, iw_begin, iw_end, &first] {
|
||||
if(iw_begin > total || iw_end > total)
|
||||
return;
|
||||
// need to make each thread unique, add an offset to current seed
|
||||
std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
|
||||
: std::random_device{}());
|
||||
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
|
||||
std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
|
||||
return ck_tile::type_convert<T>(dis(gen));
|
||||
});
|
||||
};
|
||||
threads[it] = joinable_thread(thread_f);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
|
||||
std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
|
||||
std::generate(
|
||||
first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
@@ -235,6 +298,44 @@ struct FillMonotonicSeq
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, bool IsAscending = true>
|
||||
struct FillStepRange
|
||||
{
|
||||
float start_value_{0};
|
||||
float end_value_{3};
|
||||
float step_{1};
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::generate(first, last, [=, n = start_value_]() mutable {
|
||||
auto tmp = n;
|
||||
n += step_;
|
||||
if constexpr(IsAscending)
|
||||
{
|
||||
if(n > end_value_)
|
||||
n = start_value_;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(n < end_value_)
|
||||
n = start_value_;
|
||||
}
|
||||
|
||||
return type_convert<T>(tmp);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const -> std::void_t<
|
||||
decltype(std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FillConstant
|
||||
{
|
||||
|
||||
@@ -8,12 +8,13 @@
|
||||
#include <iostream>
|
||||
#include <iomanip>
|
||||
#include <numeric>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
#include <fstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/joinable_thread.hpp"
|
||||
#include "ck_tile/host/ranges.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -213,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
|
||||
return HostTensorDescriptor(new_lengths, new_strides);
|
||||
}
|
||||
|
||||
struct joinable_thread : std::thread
|
||||
{
|
||||
template <typename... Xs>
|
||||
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
|
||||
{
|
||||
}
|
||||
|
||||
joinable_thread(joinable_thread&&) = default;
|
||||
joinable_thread& operator=(joinable_thread&&) = default;
|
||||
|
||||
~joinable_thread()
|
||||
{
|
||||
if(this->joinable())
|
||||
this->join();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename F, typename... Xs>
|
||||
struct ParallelTensorFunctor
|
||||
{
|
||||
@@ -590,6 +574,107 @@ struct HostTensor
|
||||
size() * FromSize / ToSize};
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
|
||||
{
|
||||
os << t.mDesc;
|
||||
os << "[";
|
||||
for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
|
||||
{
|
||||
os << type_convert<float>(t.mData[idx]) << " #### ";
|
||||
}
|
||||
else
|
||||
{
|
||||
os << t.mData[idx];
|
||||
}
|
||||
}
|
||||
os << "]";
|
||||
return os;
|
||||
}
|
||||
|
||||
// read data from a file, as dtype
|
||||
// the file could dumped from torch as (targeting tensor is t here)
|
||||
// numpy.savetxt("f.txt", t.view(-1).numpy())
|
||||
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy()) # from cuda to cpu to save
|
||||
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d") # save as int
|
||||
// will output f.txt, each line is a value
|
||||
// dtype=float or int, internally will cast to real type
|
||||
void loadtxt(std::string file_name, std::string dtype = "float")
|
||||
{
|
||||
std::ifstream file(file_name);
|
||||
|
||||
if(file.is_open())
|
||||
{
|
||||
std::string line;
|
||||
|
||||
index_t cnt = 0;
|
||||
while(std::getline(file, line))
|
||||
{
|
||||
if(cnt >= static_cast<index_t>(mData.size()))
|
||||
{
|
||||
throw std::runtime_error(std::string("data read from file:") + file_name +
|
||||
" is too big");
|
||||
}
|
||||
|
||||
if(dtype == "float")
|
||||
{
|
||||
mData[cnt] = type_convert<T>(std::stof(line));
|
||||
}
|
||||
else if(dtype == "int" || dtype == "int32")
|
||||
{
|
||||
mData[cnt] = type_convert<T>(std::stoi(line));
|
||||
}
|
||||
cnt++;
|
||||
}
|
||||
file.close();
|
||||
if(cnt < static_cast<index_t>(mData.size()))
|
||||
{
|
||||
std::cerr << "Warning! reading from file:" << file_name
|
||||
<< ", does not match the size of this tensor" << std::endl;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Print an error message to the standard error
|
||||
// stream if the file cannot be opened.
|
||||
throw std::runtime_error(std::string("unable to open file:") + file_name);
|
||||
}
|
||||
}
|
||||
|
||||
// can save to a txt file and read from torch as:
|
||||
// torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
|
||||
void savetxt(std::string file_name, std::string dtype = "float")
|
||||
{
|
||||
std::ofstream file(file_name);
|
||||
|
||||
if(file.is_open())
|
||||
{
|
||||
for(auto& itm : mData)
|
||||
{
|
||||
if(dtype == "float")
|
||||
file << type_convert<float>(itm) << std::endl;
|
||||
else if(dtype == "int")
|
||||
file << type_convert<int>(itm) << std::endl;
|
||||
else
|
||||
// TODO: we didn't implement operator<< for all custom
|
||||
// data types, here fall back to float in case compile error
|
||||
file << type_convert<float>(itm) << std::endl;
|
||||
}
|
||||
file.close();
|
||||
}
|
||||
else
|
||||
{
|
||||
// Print an error message to the standard error
|
||||
// stream if the file cannot be opened.
|
||||
throw std::runtime_error(std::string("unable to open file:") + file_name);
|
||||
}
|
||||
}
|
||||
|
||||
Descriptor mDesc;
|
||||
Data mData;
|
||||
};
|
||||
|
||||
27
include/ck_tile/host/joinable_thread.hpp
Normal file
27
include/ck_tile/host/joinable_thread.hpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct joinable_thread : std::thread
|
||||
{
|
||||
template <typename... Xs>
|
||||
joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
|
||||
{
|
||||
}
|
||||
|
||||
joinable_thread(joinable_thread&&) = default;
|
||||
joinable_thread& operator=(joinable_thread&&) = default;
|
||||
|
||||
~joinable_thread()
|
||||
{
|
||||
if(this->joinable())
|
||||
this->join();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
196
include/ck_tile/host/reference/reference_fused_moe.hpp
Normal file
196
include/ck_tile/host/reference/reference_fused_moe.hpp
Normal file
@@ -0,0 +1,196 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
// each slice map to one expert, and one expert can have multiple slices
|
||||
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
|
||||
// number)
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
|
||||
// 0, 1, 2, 5]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4
|
||||
// -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
|
||||
// c, f, i, o]
|
||||
//
|
||||
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
|
||||
// * length is (max_num_tokens_padded + block_size - 1) / block_size
|
||||
///
|
||||
// num_tokens_post_padded_ptr : [28]
|
||||
// num_sorted_tiles_ptr : [7]
|
||||
|
||||
template <typename AccDataType, // you only need to explcitly set this one
|
||||
typename Activation, // ck_tile::element_wise::Gelu
|
||||
typename ADataType,
|
||||
typename GDataType,
|
||||
typename DDataType,
|
||||
typename ODataType,
|
||||
typename AScaleDataType,
|
||||
typename GScaleDataType,
|
||||
typename DScaleDataType,
|
||||
typename YSmoothScaleDataType,
|
||||
typename TopkWeightDataType,
|
||||
typename IndexDataType>
|
||||
void reference_fused_moe(
|
||||
const ck_tile::HostTensor<ADataType>& a_host, // [tokens, hidden_size]
|
||||
const ck_tile::HostTensor<GDataType>& g_host, // [experts, interme_size_0, hidden_size]
|
||||
const ck_tile::HostTensor<DDataType>& d_host, // [experts, hidden_size, interme_size_1]
|
||||
const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1],
|
||||
const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
|
||||
const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size],
|
||||
const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
|
||||
ck_tile::HostTensor<ODataType>& o_host, // [tokens, hidden_size]
|
||||
const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host, // [max_num_tokens_padded]
|
||||
const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
|
||||
const ck_tile::HostTensor<IndexDataType>&
|
||||
sorted_expert_ids_host, // [(max_num_tokens_padded + block_size - 1) / block_size]
|
||||
const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]
|
||||
|
||||
const ck_tile::HostTensor<IndexDataType>&
|
||||
token_ids_host, // [tokens, topk] --> ugly!!! remove in the future
|
||||
|
||||
ck_tile::index_t block_m,
|
||||
ck_tile::index_t tokens,
|
||||
ck_tile::index_t experts,
|
||||
ck_tile::index_t hidden_size,
|
||||
ck_tile::index_t intermediate_size, // this size is for gate/up
|
||||
ck_tile::index_t topk,
|
||||
ck_tile::index_t gate_only)
|
||||
{
|
||||
assert(sorted_token_ids_host.get_num_of_dimension() == 1);
|
||||
assert(sorted_weight_host.get_num_of_dimension() == 1);
|
||||
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
|
||||
assert(num_sorted_tiles_host.get_element_size() == 1);
|
||||
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
|
||||
ck_tile::index_t intermediate_size_0 = intermediate_size;
|
||||
ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
|
||||
|
||||
// TODO: better remove this in the future, or modify the token_id value
|
||||
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
|
||||
for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
|
||||
{
|
||||
if(token_ids_host(token_id_, i_) == expert_id_)
|
||||
return i_;
|
||||
}
|
||||
throw std::runtime_error("not correct token/expert pair\n");
|
||||
return -1; // TODO: not correct!!
|
||||
};
|
||||
|
||||
ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});
|
||||
|
||||
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
|
||||
// assert();
|
||||
auto f = [&](auto i_flatten) {
|
||||
ck_tile::index_t i_tile = i_flatten / block_m;
|
||||
if(i_tile >= num_sorted_tiles)
|
||||
return;
|
||||
ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
|
||||
ck_tile::index_t i_token = sorted_token_ids_host.mData[i_flatten];
|
||||
if(i_token >= tokens)
|
||||
return;
|
||||
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
|
||||
auto weight = sorted_weight_host.mData[i_flatten];
|
||||
|
||||
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
|
||||
// first gemm
|
||||
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
|
||||
{
|
||||
AccDataType acc = static_cast<AccDataType>(0);
|
||||
for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
|
||||
{
|
||||
acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
|
||||
type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
|
||||
}
|
||||
acc_0(0, i_n) = acc;
|
||||
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
|
||||
if(gate_only)
|
||||
{
|
||||
if(intermediate_size_1 != intermediate_size_0)
|
||||
throw std::runtime_error(
|
||||
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
|
||||
", 1:" + std::to_string(intermediate_size_1));
|
||||
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
|
||||
{
|
||||
Activation{}(y(0, i_n), acc_0(0, i_n));
|
||||
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(intermediate_size_1 * 2 != intermediate_size_0)
|
||||
throw std::runtime_error(
|
||||
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
|
||||
", 1:" + std::to_string(intermediate_size_1));
|
||||
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
|
||||
{
|
||||
AccDataType tmp;
|
||||
Activation{}(tmp, acc_0(0, i_n));
|
||||
y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
|
||||
}
|
||||
}
|
||||
|
||||
// second gemm, loop along gemm-n
|
||||
ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
|
||||
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
|
||||
{
|
||||
AccDataType acc = static_cast<AccDataType>(0);
|
||||
for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
|
||||
{
|
||||
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
|
||||
}
|
||||
acc_1(0, i_n) = acc * weight; // multiple weight here
|
||||
}
|
||||
|
||||
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
|
||||
{
|
||||
out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
|
||||
}
|
||||
};
|
||||
|
||||
// make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);
|
||||
|
||||
// reduce
|
||||
auto r = [&](auto i_token) {
|
||||
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
|
||||
{
|
||||
AccDataType acc = type_convert<AccDataType>(0);
|
||||
for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
|
||||
{
|
||||
acc += out_topk_tokens(i_token, i_topk, i_n);
|
||||
}
|
||||
o_host(i_token, i_n) = type_convert<ODataType>(acc);
|
||||
}
|
||||
};
|
||||
make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());
|
||||
|
||||
(void)num_sorted_tiles_host;
|
||||
(void)sa_host;
|
||||
(void)sg_host;
|
||||
(void)sd_host;
|
||||
(void)sy_host;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
@@ -16,7 +16,7 @@ namespace ck_tile {
|
||||
*/
|
||||
template <typename DataType>
|
||||
CK_TILE_HOST void
|
||||
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
|
||||
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
|
||||
{
|
||||
const auto x_len = x.mDesc.get_lengths();
|
||||
const auto y_len = y.mDesc.get_lengths();
|
||||
@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
|
||||
std::vector<size_t> tmp(rank, 0);
|
||||
for(index_t i = 0; i < rank; i++)
|
||||
{
|
||||
tmp[dims[i]] = y_coord[i];
|
||||
tmp[perm[i]] = y_coord[i];
|
||||
}
|
||||
return tmp;
|
||||
}();
|
||||
@@ -54,4 +54,23 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
|
||||
|
||||
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
|
||||
{
|
||||
auto x_shape = x.get_lengths();
|
||||
ck_tile::index_t rank = perm.size();
|
||||
std::vector<ck_tile::index_t> y_shape = [&]() {
|
||||
std::vector<ck_tile::index_t> tmp(rank, 0);
|
||||
for(int i = 0; i < static_cast<int>(rank); i++)
|
||||
{
|
||||
tmp[i] = x_shape[perm[i]];
|
||||
}
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
HostTensor<DataType> y(y_shape);
|
||||
reference_permute(x, y, perm);
|
||||
return y;
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -572,6 +572,105 @@ struct FastGelu
|
||||
}
|
||||
};
|
||||
|
||||
struct FastGeluAsm
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
// const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
|
||||
const float c1 = -2.0 * 0.035677f;
|
||||
const float c2 = -2.0 * 0.797885f;
|
||||
const float u = x * (c1 * x * x + c2);
|
||||
const float emu = exp(u);
|
||||
y = x / (1.f + emu);
|
||||
}
|
||||
|
||||
// device code, use lower precision "__ocml_exp_f32" and "rcp"
|
||||
template <>
|
||||
CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
|
||||
const float c2 = -2.0 * 0.797885f;
|
||||
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
|
||||
float tmp;
|
||||
|
||||
asm volatile("v_mul_f32 %[v_tmp], %[v_x], %[v_x] ; x*x\n"
|
||||
"v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
|
||||
"v_mul_f32 %[v_tmp], %[v_tmp], %[v_x] ; x*(c1*x*x+c2)\n"
|
||||
"v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
|
||||
"v_exp_f32 %[v_tmp], %[v_tmp] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
|
||||
"s_nop 0 ; hazard for exp\n"
|
||||
"v_add_f32 %[v_tmp], %[v_tmp], 1.0 ; emu+1.0f\n"
|
||||
"v_rcp_f32 %[v_tmp], %[v_tmp] ; 1/(emu+1.0f)\n"
|
||||
"s_nop 0 ; hazard for rcp \n"
|
||||
"v_mul_f32 %[v_y], %[v_tmp], %[v_x] ; x * 1/(emu+1f)\n"
|
||||
: [v_y] "=v"(y), [v_tmp] "+v"(tmp)
|
||||
: [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
|
||||
:);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
|
||||
{
|
||||
const float c1 = -2.0 * 0.035677f;
|
||||
const float c2 = -2.0 * 0.797885f;
|
||||
const float u0 = x.x * (c1 * x.x * x.x + c2);
|
||||
const float emu0 = exp(u0);
|
||||
y.x = x.x / (1.f + emu0);
|
||||
const float u1 = x.y * (c1 * x.y * x.y + c2);
|
||||
const float emu1 = exp(u1);
|
||||
y.y = x.y / (1.f + emu1);
|
||||
}
|
||||
|
||||
// this is packed verion to remove data hazard for trans
|
||||
template <>
|
||||
CK_TILE_DEVICE void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
|
||||
{
|
||||
const uint32_t c1 = 0xbd92220c; // -2.0 * 0.035677f;
|
||||
float c2 = -2.0 * 0.797885f;
|
||||
const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
|
||||
float tmp0, tmp1;
|
||||
float y0 = x.x, y1 = x.y;
|
||||
|
||||
asm volatile(
|
||||
"v_mul_f32 %[v_tmp0], %[v_y0], %[v_y0] ; x*x\n"
|
||||
"v_mul_f32 %[v_tmp1], %[v_y1], %[v_y1] ; x*x\n"
|
||||
"v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
|
||||
"v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
|
||||
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_y0] ; x*(c1*x*x+c2)\n"
|
||||
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_y1] ; x*(c1*x*x+c2)\n"
|
||||
"v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
|
||||
"v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e] ; log2e*x*(c1*x*x+c2)\n"
|
||||
"v_exp_f32 %[v_tmp0], %[v_tmp0] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
|
||||
"v_exp_f32 %[v_tmp1], %[v_tmp1] ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
|
||||
"v_add_f32 %[v_tmp0], %[v_tmp0], 1.0 ; emu+1.0f\n"
|
||||
"v_add_f32 %[v_tmp1], %[v_tmp1], 1.0 ; emu+1.0f\n"
|
||||
"v_rcp_f32 %[v_tmp0], %[v_tmp0] ; 1/(emu+1.0f)\n"
|
||||
"v_rcp_f32 %[v_tmp1], %[v_tmp1] ; 1/(emu+1.0f)\n"
|
||||
"v_mul_f32 %[v_y0], %[v_tmp0], %[v_y0] ; x * 1/(emu+1f)\n"
|
||||
"v_mul_f32 %[v_y1], %[v_tmp1], %[v_y1] ; x * 1/(emu+1f)\n"
|
||||
: [v_y0] "+v"(y0),
|
||||
[v_y1] "+v"(y1),
|
||||
[v_c2] "+v"(c2),
|
||||
// NOTE! it is totally possible that c2/y0/y1 share same register, they are all local
|
||||
// tmp variables we need to expicitly hint compiler they may read+write, to allow
|
||||
// allocate different register , the side effect is c2=** may issue for every such
|
||||
// inline asm block
|
||||
[v_tmp0] "+v"(tmp0),
|
||||
[v_tmp1] "+v"(tmp1)
|
||||
: [s_c1] "s"(c1), [s_log2e] "s"(log2e_)
|
||||
:);
|
||||
y.x = y0;
|
||||
y.y = y1;
|
||||
}
|
||||
};
|
||||
|
||||
// https://paperswithcode.com/method/gelu
|
||||
// y = 0.5*x*(1+erf(x/sqrt(2)))
|
||||
struct Gelu
|
||||
|
||||
10
include/ck_tile/ops/flatmm.hpp
Normal file
10
include/ck_tile/ops/flatmm.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
@@ -0,0 +1,615 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A async load to LDS, B direct to AGPR
|
||||
// B matrix preshuffled in br*kr*w
|
||||
// require 4 wave, occupancy=1c
|
||||
// agpr useage:256
|
||||
// vgpr usage:64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112)
|
||||
//
|
||||
// for this gemm, 4 16x16x16 transposed layout
|
||||
// input A vpgpr layout
|
||||
// v0-v15: [ 0:15](gemm_m)x128(gemm_k)
|
||||
// v16-v31: [16:31](gemm_m)x128(gemm_k)
|
||||
|
||||
// input B vpgpr layout
|
||||
// v0-v15: [ 0: 15](gemm_n)x128(gemm_k)
|
||||
// v16-v31: [ 64: 79](gemm_n)x128(gemm_k)
|
||||
// ......................
|
||||
// v111-v127: [448:463](gemm_n)x128(gemm_k)
|
||||
|
||||
// output C vpgpr layout
|
||||
// v0-v3 : [ 0:15](gemm_m)x[ 0: 15](gemm_n)
|
||||
// v4-v7 : [16:31](gemm_m)x[ 0: 15](gemm_n)
|
||||
// v8-v11: [ 0:15](gemm_m)x[64: 79](gemm_n)
|
||||
// v12-v15: [16:31](gemm_m)x[64: 79](gemm_n)
|
||||
// ......................
|
||||
// v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n)
|
||||
// v60-v63: [16:31](gemm_m)x[448:463](gemm_n)
|
||||
// Shared policy for the 32x512x128 flatmm micro-kernel (fp16/bf16 specializations below).
// Holds the block/warp tiling constants, the C-tile distribution, and the LDS
// store/load descriptors for the A operand.
struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
{
    // per-block GEMM tile sizes
    static constexpr index_t Block_M = 32;
    static constexpr index_t Block_N = 512;
    static constexpr index_t Block_K = 128;

    // wave layout inside one block: 1 (M) x 4 (N) x 1 (K)
    static constexpr index_t WarpPerBlock_M = 1;
    static constexpr index_t WarpPerBlock_N = 4;
    static constexpr index_t WarpPerBlock_K = 1;

    static constexpr index_t NumWarps = 4;

    // per-warp MFMA tile sizes
    static constexpr index_t Warp_M = 16;
    static constexpr index_t Warp_N = 16;
    static constexpr index_t Warp_K = 32; // 16 * SubKPacks

    static constexpr index_t BlockSize = 256;

    static constexpr index_t SubKPacks = 2; // this is used to guarantee every thread can do dwordx4

    // TODO: note Nr/Kr/W need to consider SubKPacks
    static constexpr index_t Block_W  = Warp_N * Warp_K;  // 512 element
    static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
    static constexpr index_t Block_Kr = Block_K / Warp_K; // 4

    // per-wave repeats along each dimension
    static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
    static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
    static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4

    // Build the static tile distribution describing how the C accumulator tile
    // is spread over waves/lanes (transposed-C MFMA layout).
    static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
    {
        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<2, 1>, // !! note here is different
            sequence<0, 0>>{};

        using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        return c_block_dstr;
    }

    // Allocate the per-thread C accumulator tile (fp32) with the distribution above.
    static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
    {
        using CDataType = float;
        constexpr auto c_block_dstr = MakeCBlockDist();
        auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
    }

    // Descriptor for the async global->LDS store of the A tile.
    // Returns an (issues, warps, lanes) view; layout depends on whether one K row
    // spans more than a wavefront of lanes.
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
    {
        // A async->LDS
        // constexpr index_t Block_M = Problem::BlockShape::Block_M0;
        // constexpr index_t Block_K = Problem::BlockShape::Block_K0;
        // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
        constexpr index_t warpSize = ck_tile::get_warp_size();
        // constexpr index_t NumWarps = Problem::BlockShape::NumWarps;

        constexpr index_t KPack_  = 8; // GetSmemKPack_A<Problem>(); // LDS
        constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
        constexpr index_t KPad    = KPack_; // pad between warps

        static_assert(Block_K % KVector == 0);
        constexpr index_t LanesPerK = Block_K / KVector; // how many threads loading K
        if constexpr(LanesPerK >= warpSize)
        {
            // need multiple waves to load K
            static_assert(LanesPerK % warpSize == 0);
            constexpr index_t wavesPerK = LanesPerK / warpSize;
            if constexpr(wavesPerK > NumWarps)
            {
                // TODO: need multiple issues along K to load all data
            }
            else
            {
                constexpr index_t wavesPerM = NumWarps / wavesPerK;
                constexpr index_t NumIssues = Block_M / wavesPerM;
                constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                    make_tuple(number<NumIssues>{}, // m0
                               number<wavesPerM>{}, // m1
                               number<wavesPerK>{}, // k0
                               number<warpSize>{},  // k1
                               number<KVector>{}),  // k2
                    make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{},  // m0
                               number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
                               number<warpSize * KVector + KPad>{},             // k0
                               number<KVector>{},                               // k1
                               number<1>{}),                                    // k2
                    number<KVector>{}, // lds store vector(actually no explicit store)
                    number<1>{});

                constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
                    lds_block_desc_0,
                    make_tuple(
                        make_pass_through_transform(number<NumIssues>{}),
                        make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
                        make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
                    make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

                return lds_block_desc_issues_warps_lanes;
            }
        }
        else
        {
            // lanes within a wave load different M but same K
            static_assert(warpSize % LanesPerK == 0);
            constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
            constexpr index_t NumIssues  = Block_M / (LaneGroups * NumWarps);

            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NumIssues>{},  // m0
                           number<LaneGroups>{}, // m1
                           number<NumWarps>{},   // m2
                           number<LanesPerK>{},  // k0
                           number<KVector>{}),   // k1
                make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
                           number<Block_K>{},                              // m1
                           number<warpSize * KVector + KPad>{},            // m2
                           number<KVector>{},                              // k0
                           number<1>{}),                                   // k1
                number<KVector>{}, // lds store vector(actually no explicit store)
                number<1>{});

            constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(make_pass_through_transform(number<NumIssues>{}),
                           make_pass_through_transform(number<NumWarps>{}),
                           make_merge_transform(make_tuple(
                               number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
                make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

            return lds_block_desc_issues_warps_lanes;
        }
    }

    // template <typename Problem>
    // Descriptor for reading the A tile back from LDS into registers,
    // flattened to an (m, k) view. Every wave reads the same layout.
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
    {
        // load from LDS to register, every wave has same layout
        constexpr index_t KPack_ = 8;     // GetSmemKPack_A<Problem>(); // LDS
        constexpr index_t KPad   = KPack_; // pad between warps

        // MFMA 16x16 lane layout: 16 lanes along M, 4 lane-groups along K,
        // 4 elements per lane per K step, 2 K iterations
        constexpr index_t kAMLane     = 16;
        constexpr index_t kABKLane    = 4;
        constexpr index_t kABKPerLane = 4;
        constexpr index_t kKIter      = 2;
        static_assert(KPack_ == (kABKPerLane * kKIter));

        constexpr auto lds_block_desc_0 =
            make_naive_tensor_descriptor(make_tuple(number<Repeat_M>{},  // m0 y
                                                    number<kAMLane>{},   // m1 p
                                                    number<Repeat_K>{},  // k0 y
                                                    number<kABKLane>{},  // k1 p
                                                    number<KPack_>{}),   // k2 y-vector
                                         make_tuple(number<kAMLane*(Block_K + KPad)>{}, // m0
                                                    number<Block_K + KPad>{},           // m1
                                                    number<kABKLane * KPack_>{},        // k0
                                                    number<KPack_>{},                   // k1
                                                    number<1>{}),                       // k2
                                         number<KPack_>{}, // lds load vector
                                         number<1>{});

        constexpr auto lds_desc_m_k = transform_tensor_descriptor(
            lds_block_desc_0,
            make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
                       make_merge_transform(
                           make_tuple(number<Repeat_K>{}, number<kABKLane>{}, number<KPack_>{}))),
            make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return lds_desc_m_k;
    }

    // Per-warp tile-distribution encoding for the A operand register layout.
    static constexpr auto GetGemm_AWarpEnc()
    {
        constexpr index_t kAMLane     = 16;
        constexpr index_t kABKLane    = 4;
        constexpr index_t kABKPerLane = 4;
        constexpr index_t kKIter      = 2;

        using enc_ = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<kAMLane>, sequence<kABKLane, kABKPerLane * kKIter>>,
            tuple<sequence<2, 1>>,
            tuple<sequence<0, 0>>,
            sequence<2>,
            sequence<1>>;
        return enc_{};
    }

    // LDS bytes needed for the A tile: 32 rows of (128 K elements + 8 pad), 2 bytes each.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return 32 * (128 + 8) * sizeof(bf16_t);
    }
};
|
||||
|
||||
// bf16 instantiation of the 32x512x128 flatmm micro-kernel.
// The whole main loop lives in the included GFX9 inline-asm micro-kernel; this
// wrapper only sets up LDS windows/offsets, binds registers, and unpacks the
// accumulators into a distributed C tile.
struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
    using ADataType = bf16_t;
    using BDataType = bf16_t;

    // TODO: needs to be paired with tile_window_linear!
    // TODO: must call init_raw() before calling this function!
    //
    // res_a/res_b: raw resources for A/B, indexed [0..3] — presumably 4-dword
    //              buffer descriptors for the asm's buffer loads (TODO confirm).
    // cached_coords_a/b: per-thread element offsets into A/B (scaled to bytes here).
    // smem: LDS scratch for the A tile (see GetSmemSize()).
    // k: total K extent; the asm loops k / Block_K times.
    // tile_offset_a/b: per-unroll element advance for each operand.
    // Returns the accumulated C block tile (fp32, distribution MakeCBlockDist()).
    template <typename ARes, typename ACoords, typename BRes, typename BCoords>
    CK_TILE_DEVICE auto
    operator()(const ARes& res_a,
               const ACoords& cached_coords_a,
               const BRes& res_b,
               const BCoords& cached_coords_b,
               CK_TILE_LDS_ADDR void* smem,
               index_t k,
               index_t tile_offset_a, // for each tile, the offset to move for each unroll
               index_t tile_offset_b) // for each tile, the offset to move for each unroll
    {
        static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
        static_assert(BCoords::size() == Repeat_N);

        // store-side LDS window for the async A copy
        auto a_sst = make_tile_window(
            make_tensor_view<address_space_enum::lds>(
                reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
            MakeLdsStoreDesc_A().get_lengths(),
            {0, 0, 0});

        // load-side LDS window with the per-warp A register distribution
        auto a_sld = [&]() {
            constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
            constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
                sequence<WarpPerBlock_N>,
                tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
                tuple<sequence<1, 0>>,
                tuple<sequence<1, 0>>,
                sequence<1, 2>,
                sequence<0, 0>>{};
            constexpr auto a_block_dstr_encode =
                detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(
                    reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
                MakeLdsLoadDesc_A().get_lengths(),
                {0, 0},
                make_static_tile_distribution(a_block_dstr_encode));
        }();

        // element offsets -> byte offsets for the asm
        const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
        const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);

        const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
        constexpr auto smem_buf_size =
            MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
        static_assert(a_sld.get_num_of_access() == 8);
        // compile-time LDS read offsets (bytes), one per access, passed as "n" constraints
        constexpr auto sld_os = generate_tuple(
            [&](auto i_access) {
                return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
            },
            number<a_sld.get_num_of_access()>{});

        index_t loop_cnt = k / Block_K;

        // this is the acc thread buffer
        fp32x4_t v_acc[16]{.0f};

        // B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
        // clang-format off
        asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
            // outputs: loop counter, 16x fp32x4 accumulators, smem pointer
            : [s_loop_cnt]"+s"(loop_cnt),
              [v_acc_0]"+v"(v_acc[0]),
              [v_acc_1]"+v"(v_acc[1]),
              [v_acc_2]"+v"(v_acc[2]),
              [v_acc_3]"+v"(v_acc[3]),
              [v_acc_4]"+v"(v_acc[4]),
              [v_acc_5]"+v"(v_acc[5]),
              [v_acc_6]"+v"(v_acc[6]),
              [v_acc_7]"+v"(v_acc[7]),
              [v_acc_8]"+v"(v_acc[8]),
              [v_acc_9]"+v"(v_acc[9]),
              [v_acc_10]"+v"(v_acc[10]),
              [v_acc_11]"+v"(v_acc[11]),
              [v_acc_12]"+v"(v_acc[12]),
              [v_acc_13]"+v"(v_acc[13]),
              [v_acc_14]"+v"(v_acc[14]),
              [v_acc_15]"+v"(v_acc[15]),
              [s_mem_]"+r"(smem)
            // inputs: A/B resources, per-thread byte offsets, LDS parameters
            : [s_res_a0]"s"(res_a[0]),
              [s_res_a1]"s"(res_a[1]),
              [s_res_a2]"s"(res_a[2]),
              [s_res_a3]"s"(res_a[3]),
              [s_res_b0]"s"(res_b[0]),
              [s_res_b1]"s"(res_b[1]),
              [s_res_b2]"s"(res_b[2]),
              [s_res_b3]"s"(res_b[3]),
              [v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
              [v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
              [v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
              [v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
              [v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
              [v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
              [v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
              [v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),

              [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
              [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
              [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
              [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
              [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
              [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
              [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
              [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),

              [v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
              [s_m0_init]"s"(m0_init_value),
              [s_size_per_issue]"s"(size_per_issue),
              [smem_sz]"n"(smem_buf_size), //(smem_buf_size),
              [sld_os_0]"n"(sld_os[number<0>{}].value),
              [sld_os_1]"n"(sld_os[number<1>{}].value),
              [sld_os_2]"n"(sld_os[number<2>{}].value),
              [sld_os_3]"n"(sld_os[number<3>{}].value),
              [sld_os_4]"n"(sld_os[number<4>{}].value),
              [sld_os_5]"n"(sld_os[number<5>{}].value),
              [sld_os_6]"n"(sld_os[number<6>{}].value),
              [sld_os_7]"n"(sld_os[number<7>{}].value),
              [s_tile_os_a]"s"(tile_offset_a_bytes),
              [s_tile_os_b]"s"(tile_offset_b_bytes)
            // clobbers: all 256 AGPRs, scratch SGPRs, and VGPRs v64-v127
            : "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
              "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
              "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
              "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
              "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
              "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
              "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
              "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
              "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
              "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
              "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
              "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
              "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
              "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
              "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
              "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
              "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
              "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
              "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
              "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
              "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
              "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
              "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
              "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
              "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
              "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
              "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
              "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
              "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
              "a252", "a253", "a254", "a255",
              "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
              "s86", // s86 as tmp
              "v64", "v65", "v66", "v67", "v68", "v69",
              "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
              "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
              "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
              "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
              "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
              "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
              "v124", "v125", "v126", "v127"
        );
        // clang-format on
#pragma clang diagnostic pop

        // return local scratch: copy the 16 fp32x4 accumulators into the
        // distributed C tile's thread buffer (4 floats each)
        auto c = MakeCBlockTile();
        for(auto i = 0; i < 16; i++)
        {
            c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
            c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
            c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
            c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
        }
        return c;
    }
};
|
||||
|
||||
// fp16 instantiation of the 32x512x128 flatmm micro-kernel.
// Identical structure to the BF16 version above; only the data types and the
// MFMA macro selected for the included asm micro-kernel differ.
struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
    using ADataType = fp16_t;
    using BDataType = fp16_t;

    // TODO: needs to be paired with tile_window_linear!
    // TODO: must call init_raw() before calling this function!
    //
    // res_a/res_b: raw resources for A/B, indexed [0..3] — presumably 4-dword
    //              buffer descriptors for the asm's buffer loads (TODO confirm).
    // cached_coords_a/b: per-thread element offsets into A/B (scaled to bytes here).
    // smem: LDS scratch for the A tile (see GetSmemSize()).
    // k: total K extent; the asm loops k / Block_K times.
    // tile_offset_a/b: per-unroll element advance for each operand.
    // Returns the accumulated C block tile (fp32, distribution MakeCBlockDist()).
    template <typename ARes, typename ACoords, typename BRes, typename BCoords>
    CK_TILE_DEVICE auto
    operator()(const ARes& res_a,
               const ACoords& cached_coords_a,
               const BRes& res_b,
               const BCoords& cached_coords_b,
               CK_TILE_LDS_ADDR void* smem,
               index_t k,
               index_t tile_offset_a, // for each tile, the offset to move for each unroll
               index_t tile_offset_b) // for each tile, the offset to move for each unroll
    {
        static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
        static_assert(BCoords::size() == Repeat_N);

        // store-side LDS window for the async A copy
        auto a_sst = make_tile_window(
            make_tensor_view<address_space_enum::lds>(
                reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
            MakeLdsStoreDesc_A().get_lengths(),
            {0, 0, 0});

        // load-side LDS window with the per-warp A register distribution
        auto a_sld = [&]() {
            constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
            constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
                sequence<WarpPerBlock_N>,
                tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_K>>,
                tuple<sequence<1, 0>>,
                tuple<sequence<1, 0>>,
                sequence<1, 2>,
                sequence<0, 0>>{};
            constexpr auto a_block_dstr_encode =
                detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(
                    reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
                MakeLdsLoadDesc_A().get_lengths(),
                {0, 0},
                make_static_tile_distribution(a_block_dstr_encode));
        }();

        // element offsets -> byte offsets for the asm
        const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
        const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);

        const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
        constexpr auto smem_buf_size =
            MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);
        static_assert(a_sld.get_num_of_access() == 8);
        // compile-time LDS read offsets (bytes), one per access, passed as "n" constraints
        constexpr auto sld_os = generate_tuple(
            [&](auto i_access) {
                return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
            },
            number<a_sld.get_num_of_access()>{});

        index_t loop_cnt = k / Block_K;

        // this is the acc thread buffer
        fp32x4_t v_acc[16]{.0f};

        // B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
        // clang-format off
        asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
            // outputs: loop counter, 16x fp32x4 accumulators, smem pointer
            : [s_loop_cnt]"+s"(loop_cnt),
              [v_acc_0]"+v"(v_acc[0]),
              [v_acc_1]"+v"(v_acc[1]),
              [v_acc_2]"+v"(v_acc[2]),
              [v_acc_3]"+v"(v_acc[3]),
              [v_acc_4]"+v"(v_acc[4]),
              [v_acc_5]"+v"(v_acc[5]),
              [v_acc_6]"+v"(v_acc[6]),
              [v_acc_7]"+v"(v_acc[7]),
              [v_acc_8]"+v"(v_acc[8]),
              [v_acc_9]"+v"(v_acc[9]),
              [v_acc_10]"+v"(v_acc[10]),
              [v_acc_11]"+v"(v_acc[11]),
              [v_acc_12]"+v"(v_acc[12]),
              [v_acc_13]"+v"(v_acc[13]),
              [v_acc_14]"+v"(v_acc[14]),
              [v_acc_15]"+v"(v_acc[15]),
              [s_mem_]"+r"(smem)
            // inputs: A/B resources, per-thread byte offsets, LDS parameters
            : [s_res_a0]"s"(res_a[0]),
              [s_res_a1]"s"(res_a[1]),
              [s_res_a2]"s"(res_a[2]),
              [s_res_a3]"s"(res_a[3]),
              [s_res_b0]"s"(res_b[0]),
              [s_res_b1]"s"(res_b[1]),
              [s_res_b2]"s"(res_b[2]),
              [s_res_b3]"s"(res_b[3]),
              [v_os_a0]"v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
              [v_os_a1]"v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
              [v_os_a2]"v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
              [v_os_a3]"v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
              [v_os_a4]"v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
              [v_os_a5]"v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
              [v_os_a6]"v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
              [v_os_a7]"v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),

              [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
              [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
              [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
              [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
              [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
              [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
              [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
              [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),

              [v_os_slda]"v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
              [s_m0_init]"s"(m0_init_value),
              [s_size_per_issue]"s"(size_per_issue),
              [smem_sz]"n"(smem_buf_size), //(smem_buf_size),
              [sld_os_0]"n"(sld_os[number<0>{}].value),
              [sld_os_1]"n"(sld_os[number<1>{}].value),
              [sld_os_2]"n"(sld_os[number<2>{}].value),
              [sld_os_3]"n"(sld_os[number<3>{}].value),
              [sld_os_4]"n"(sld_os[number<4>{}].value),
              [sld_os_5]"n"(sld_os[number<5>{}].value),
              [sld_os_6]"n"(sld_os[number<6>{}].value),
              [sld_os_7]"n"(sld_os[number<7>{}].value),
              [s_tile_os_a]"s"(tile_offset_a_bytes),
              [s_tile_os_b]"s"(tile_offset_b_bytes)
            // clobbers: all 256 AGPRs, scratch SGPRs, and VGPRs v64-v127
            : "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
              "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
              "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
              "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
              "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
              "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
              "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
              "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
              "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
              "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
              "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
              "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
              "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
              "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
              "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
              "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
              "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
              "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
              "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
              "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
              "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
              "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
              "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
              "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
              "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
              "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
              "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
              "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
              "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
              "a252", "a253", "a254", "a255",
              "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
              "s86", // s86 as tmp
              "v64", "v65", "v66", "v67", "v68", "v69",
              "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
              "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89",
              "v90", "v91", "v92", "v93", "v94", "v95", "v96", "v97", "v98", "v99",
              "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107",
              "v108", "v109", "v110", "v111", "v112", "v113", "v114", "v115",
              "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123",
              "v124", "v125", "v126", "v127"
        );
        // clang-format on
#pragma clang diagnostic pop

        // return local scratch: copy the 16 fp32x4 accumulators into the
        // distributed C tile's thread buffer (4 floats each)
        auto c = MakeCBlockTile();
        for(auto i = 0; i < 16; i++)
        {
            c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
            c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
            c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
            c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
        }
        return c;
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,562 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// "S"tream update output along "N"
|
||||
// A in smem, B load from global
|
||||
// require 4 wave, occupancy=1c
|
||||
// Base policy for the 32x128x512 stream-N ("Sn") flatmm micro-kernel:
// the output is updated in streaming fashion along N, with A resident in
// smem and B loaded from global.
struct FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
    // per-block GEMM tile sizes
    static constexpr index_t Block_M = 32;
    static constexpr index_t Block_N = 128;
    static constexpr index_t Block_K = 512;

    // wave layout inside one block: 1 (M) x 4 (N) x 1 (K)
    static constexpr index_t WarpPerBlock_M = 1;
    static constexpr index_t WarpPerBlock_N = 4;
    static constexpr index_t WarpPerBlock_K = 1;

    // per-warp MFMA tile sizes
    static constexpr index_t Warp_M = 16;
    static constexpr index_t Warp_N = 16;
    static constexpr index_t Warp_K = 32;

    static constexpr index_t BlockSize = 256;

    // static constexpr index_t KPack = 2; // this is used to guarantee every thread can do dwordx4

    // TODO: note Nr/Kr/W need to consider KPack
    static constexpr index_t Block_W  = Warp_N * Warp_K;  // 512 element
    static constexpr index_t Block_Nr = Block_N / Warp_N; // 8 (NOTE: stale "32 element, 4 per wave" comment removed — 128/16 = 8)
    static constexpr index_t Block_Kr = Block_K / Warp_K; // 16 (NOTE: stale "4" comment corrected — 512/32 = 16)

    // per-wave repeats along each dimension
    static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
    static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 2
    static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 16

    // Build the static tile distribution for the C accumulator tile
    // (transposed-C MFMA layout, same construction as the non-Sn kernel).
    static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
    {
        constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<2, 1>, // !! note here is different
            sequence<0, 0>>{};

        using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        return c_block_dstr;
    }

    // LDS bytes for the output-shuffle staging buffer.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        //                       y     y     p     p       p     y
        // reg before shfl   M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
        // but order is N0*M0*Nv
        // in LDS we need store as
        //    M0(2)* N0(2) * Nl(4)  * Nw(4)  * (Mw(16)*Nv(4) + 4)
        //    y      y       wave-id  lid/16    lid%16  v
        return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t);
    }
};
|
||||
|
||||
// Flatmm shuffle-store ("Sn") micro-kernel: 32x128x512 block tile, 1x4x1 wave
// layout, 16x16 mfma, bf16 B operand and bf16 output.
// The whole N-loop is one hand-written gfx9 inline-asm blob pulled in from
// "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"; this C++ wrapper only
// computes LDS offsets and binds registers/operands for that blob.
struct FlatmmSn_32x128x512_1x4x1_16x16x32_BF16 : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
    using BDataType = bf16_t;
    using ODataType = bf16_t;

    // TODO: need paired with tile_window_linear!
    // TODO: need call init_raw() before call this function!
    // template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
    template <typename BRes,
              typename BCoords,
              typename ORes,
              typename OCoords,
              typename OFlags,
              typename ScaleTensor>
    CK_TILE_DEVICE auto
    operator()(const BRes& res_b,              // buffer-resource words for the B matrix
               const BCoords& cached_coords_b, // 8 cached per-thread element offsets into B
               const ORes& res_o,              // buffer-resource words for the output
               const OCoords& cached_coords_o, // 8 cached per-thread element offsets into O
               const OFlags& o_flags, // this should be in sgpr
               CK_TILE_LDS_ADDR void* smem,
               index_t n, // loop along n dim
               const ScaleTensor& scale_,
               index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
               index_t tile_offset_o)
    {
        static_assert(BCoords::size() == 8); // 8
        static_assert(OCoords::size() == 8);

        // byte strides the asm uses to advance the B / O pointers between tiles
        const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
        const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);

        static_assert(ScaleTensor::size() == 2);
        float s0 = scale_[number<0>{}];
        float s1 = scale_[number<1>{}];

        index_t loop_cnt = n / Block_N; // iterations of the asm N-loop

        // Pin the 32 fp32 accumulators to v64..v95: the asm body addresses these
        // registers by number, so the allocation must not move.
        register float v_c0 asm("v64");
        register float v_c1 asm("v65");
        register float v_c2 asm("v66");
        register float v_c3 asm("v67");
        register float v_c4 asm("v68");
        register float v_c5 asm("v69");
        register float v_c6 asm("v70");
        register float v_c7 asm("v71");
        register float v_c8 asm("v72");
        register float v_c9 asm("v73");
        register float v_c10 asm("v74");
        register float v_c11 asm("v75");
        register float v_c12 asm("v76");
        register float v_c13 asm("v77");
        register float v_c14 asm("v78");
        register float v_c15 asm("v79");
        register float v_c16 asm("v80");
        register float v_c17 asm("v81");
        register float v_c18 asm("v82");
        register float v_c19 asm("v83");
        register float v_c20 asm("v84");
        register float v_c21 asm("v85");
        register float v_c22 asm("v86");
        register float v_c23 asm("v87");
        register float v_c24 asm("v88");
        register float v_c25 asm("v89");
        register float v_c26 asm("v90");
        register float v_c27 asm("v91");
        register float v_c28 asm("v92");
        register float v_c29 asm("v93");
        register float v_c30 asm("v94");
        register float v_c31 asm("v95");
        // Constants consumed by the fp32->bf16 conversion inside the asm
        // (NaN quieting + rounding bias, see _UK_PK_CVT_ in the .inc file).
        int32_t nan_hi = 0x7fff0000;
        int32_t nan_lo = 0x00007fff;

        // in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
        // every threads need 8xK in contiguous register
        // ... and every wave need the same data
        int lane_id = threadIdx.x % 64;
        int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
        sld_y_os *= 2; // element index -> byte offset (2-byte elements)

        //                    y     y      p       p      p      y
        // reg before shfl  M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
        // but order is N0*M0*Nv
        // in LDS we need store as
        //  M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
        //    y      y    wave-id  lid/16   lid%16     v
        // sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
        int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
        sfl_sst *= 2; // element index -> byte offset (2-byte elements)

        // from LDS we need load as
        // M0(2)* N0(2) * Nl(4)     *     Nw(4)     *    (Mw(16) * Nv(4) + 4)
        //       ( 2 issue) (rem 32-lane) (4 wave*4issue)   2lane*1issue(pk2)
        // sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
        int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
        sfl_sld *= 2; // element index -> byte offset (2-byte elements)

        // B nr->kr
        // clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
        // Select the bf16 flavor of the micro-kernel (mfma / cvt / atomic-add)
        // before textually including the shared asm body.
        asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
            :[smem_]"+r"(smem),
            [s_loop_cnt]"+s"(loop_cnt),
            [c0]"+v" (v_c0),
            [c1]"+v" (v_c1),
            [c2]"+v" (v_c2),
            [c3]"+v" (v_c3),
            [c4]"+v" (v_c4),
            [c5]"+v" (v_c5),
            [c6]"+v" (v_c6),
            [c7]"+v" (v_c7),
            [c8]"+v" (v_c8),
            [c9]"+v" (v_c9),
            [c10]"+v"(v_c10),
            [c11]"+v"(v_c11),
            [c12]"+v"(v_c12),
            [c13]"+v"(v_c13),
            [c14]"+v"(v_c14),
            [c15]"+v"(v_c15),
            [c16]"+v"(v_c16),
            [c17]"+v"(v_c17),
            [c18]"+v"(v_c18),
            [c19]"+v"(v_c19),
            [c20]"+v"(v_c20),
            [c21]"+v"(v_c21),
            [c22]"+v"(v_c22),
            [c23]"+v"(v_c23),
            [c24]"+v"(v_c24),
            [c25]"+v"(v_c25),
            [c26]"+v"(v_c26),
            [c27]"+v"(v_c27),
            [c28]"+v"(v_c28),
            [c29]"+v"(v_c29),
            [c30]"+v"(v_c30),
            [c31]"+v"(v_c31)
            :
            [sld_a_base]"n"(0),
            [shfl_base]"n"(0),
            [v_sld_y_os]"v"(sld_y_os),
            [v_sfl_sld]"v"(sfl_sld),
            [v_sfl_sst]"v"(sfl_sst),
            [s_res_o0]"s"(res_o[0]),
            [s_res_o1]"s"(res_o[1]),
            //[s_res_o2]"s"(res_o[2]),
            //[s_res_o3]"s"(res_o[3]),
            [s_res_b0]"s"(res_b[0]),
            [s_res_b1]"s"(res_b[1]),
            [s_res_b2]"s"(res_b[2]),
            [s_res_b3]"s"(res_b[3]),
            [v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
            [v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
            [v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
            [v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
            [v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
            [v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
            [v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
            [v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
            [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
            [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
            [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
            [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
            [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
            [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
            [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
            [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),

            [s_tile_os_o]"s"(tile_stride_o_bytes),
            [s_tile_os_b]"s"(tile_stride_b_bytes),
            [scale_0]"v"(s0),
            [scale_1]"v"(s1),
            [v_nan_lo]"v"(nan_lo),
            [v_nan_hi]"v"(nan_hi),
            [s_execflag_0]"s"(o_flags[number<0>{}]),
            [s_execflag_1]"s"(o_flags[number<1>{}]),
            [s_execflag_2]"s"(o_flags[number<2>{}]),
            [s_execflag_3]"s"(o_flags[number<3>{}]),
            [s_execflag_4]"s"(o_flags[number<4>{}]),
            [s_execflag_5]"s"(o_flags[number<5>{}]),
            [s_execflag_6]"s"(o_flags[number<6>{}]),
            [s_execflag_7]"s"(o_flags[number<7>{}])
            :
            // Conservative clobbers: all accumulator registers, plus the sgpr/vgpr
            // scratch the asm body uses (s36/s37 for compare masks, s52 for v_perm,
            // v50/v54/v55 for conversion temps, v128+ for staged operands).
            "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
            "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
            "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
            "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
            "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
            "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
            "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
            "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
            "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
            "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
            "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
            "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
            "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
            "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
            "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
            "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
            "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
            "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
            "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
            "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
            "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
            "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
            "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
            "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
            "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
            "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
            "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
            "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
            "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
            "a252", "a253", "a254", "a255",
            "s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
            "s36", "s37",
            "v50", "v54", "v55",
            "v64","v65","v66","v67","v68","v69","v70","v71",
            "v72","v73","v74","v75","v76","v77","v78","v79",
            "v80","v81","v82","v83","v84","v85","v86","v87",
            "v88","v89","v90","v91","v92","v93","v94","v95",
            "v128", "v129", "v130", "v131",
            "v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
            "v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
            "v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
            "v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
            "v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
            "v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
            "v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
            "v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
            "v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
            "v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
            "v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
            "v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
            "v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
            "v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
            "v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
            "v252", "v253", "v254", "v255"
        );
#pragma clang diagnostic pop
        // clang-format on
    }
};
|
||||
|
||||
// Flatmm shuffle-store ("Sn") micro-kernel: 32x128x512 block tile, 1x4x1 wave
// layout, 16x16 mfma, fp16 B operand and fp16 output written via
// global_atomic_pk_add_f16.  Identical structure to the BF16 variant above;
// only the data-type aliases and the selected asm flavor differ.
struct FlatmmSn_32x128x512_1x4x1_16x16x32_FP16 : public FlatmmSn_32x128x512_1x4x1_16x16x32_Base
{
    // FIX: these were bf16_t (copy-paste from the BF16 variant).  The asm path
    // selected below (CK_TILE_FLATMM_UK_MFMA_FP16) uses v_mfma_f32_16x16x16_f16,
    // v_cvt_f16_f32 and global_atomic_pk_add_f16, i.e. half-precision elements.
    // sizeof(fp16_t) == sizeof(bf16_t) == 2, so all byte offsets are unchanged.
    using BDataType = fp16_t;
    using ODataType = fp16_t;

    // TODO: need paired with tile_window_linear!
    // TODO: need call init_raw() before call this function!
    // template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
    template <typename BRes,
              typename BCoords,
              typename ORes,
              typename OCoords,
              typename OFlags,
              typename ScaleTensor>
    CK_TILE_DEVICE auto
    operator()(const BRes& res_b,              // buffer-resource words for the B matrix
               const BCoords& cached_coords_b, // 8 cached per-thread element offsets into B
               const ORes& res_o,              // buffer-resource words for the output
               const OCoords& cached_coords_o, // 8 cached per-thread element offsets into O
               const OFlags& o_flags, // this should be in sgpr
               CK_TILE_LDS_ADDR void* smem,
               index_t n, // loop along n dim
               const ScaleTensor& scale_,
               index_t tile_offset_b, // stride b is fixed to blockKr * blockW, but still can adjust
               index_t tile_offset_o)
    {
        static_assert(BCoords::size() == 8); // 8
        static_assert(OCoords::size() == 8);

        // byte strides the asm uses to advance the B / O pointers between tiles
        const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
        const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);

        static_assert(ScaleTensor::size() == 2);
        float s0 = scale_[number<0>{}];
        float s1 = scale_[number<1>{}];

        index_t loop_cnt = n / Block_N; // iterations of the asm N-loop

        // Pin the 32 fp32 accumulators to v64..v95: the asm body addresses these
        // registers by number, so the allocation must not move.
        register float v_c0 asm("v64");
        register float v_c1 asm("v65");
        register float v_c2 asm("v66");
        register float v_c3 asm("v67");
        register float v_c4 asm("v68");
        register float v_c5 asm("v69");
        register float v_c6 asm("v70");
        register float v_c7 asm("v71");
        register float v_c8 asm("v72");
        register float v_c9 asm("v73");
        register float v_c10 asm("v74");
        register float v_c11 asm("v75");
        register float v_c12 asm("v76");
        register float v_c13 asm("v77");
        register float v_c14 asm("v78");
        register float v_c15 asm("v79");
        register float v_c16 asm("v80");
        register float v_c17 asm("v81");
        register float v_c18 asm("v82");
        register float v_c19 asm("v83");
        register float v_c20 asm("v84");
        register float v_c21 asm("v85");
        register float v_c22 asm("v86");
        register float v_c23 asm("v87");
        register float v_c24 asm("v88");
        register float v_c25 asm("v89");
        register float v_c26 asm("v90");
        register float v_c27 asm("v91");
        register float v_c28 asm("v92");
        register float v_c29 asm("v93");
        register float v_c30 asm("v94");
        register float v_c31 asm("v95");
        // Only used by the bf16 conversion flavor; still bound as asm operands
        // because the operand list is shared between flavors.
        int32_t nan_hi = 0x7fff0000;
        int32_t nan_lo = 0x00007fff;

        // in smem, the layout is M0(2)*K0(128)*M1(16)*K1(4)
        // every threads need 8xK in contiguous register
        // ... and every wave need the same data
        int lane_id = threadIdx.x % 64;
        int sld_y_os = (lane_id % 16) * 4 + (lane_id / 16) * 128;
        sld_y_os *= 2; // element index -> byte offset (2-byte elements)

        //                    y     y      p       p      p      y
        // reg before shfl  M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
        // but order is N0*M0*Nv
        // in LDS we need store as
        //  M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
        //    y      y    wave-id  lid/16   lid%16     v
        // sst(v3) = (v0/16*34 + v0%16 * 2 + wid*136) * 4
        int sfl_sst = (threadIdx.x % 16 * 4) + (threadIdx.x / 16) * (64 + 4);
        sfl_sst *= 2; // element index -> byte offset (2-byte elements)

        // from LDS we need load as
        // M0(2)* N0(2) * Nl(4)     *     Nw(4)     *    (Mw(16) * Nv(4) + 4)
        //       ( 2 issue) (rem 32-lane) (4 wave*4issue)   2lane*1issue(pk2)
        // sld(v4) = v0/2 *34*4 + v0 % 2 *4 + wid*2 *4
        int sfl_sld = (lane_id % 2) * 2 + (lane_id / 2) * (64 + 4) + (threadIdx.x / 64) * 4;
        sfl_sld *= 2; // element index -> byte offset (2-byte elements)

        // B nr->kr
        // clang-format off
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
        // Select the fp16 flavor of the micro-kernel (mfma / cvt / atomic-add)
        // before textually including the shared asm body.
        asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
            :[smem_]"+r"(smem),
            [s_loop_cnt]"+s"(loop_cnt),
            [c0]"+v" (v_c0),
            [c1]"+v" (v_c1),
            [c2]"+v" (v_c2),
            [c3]"+v" (v_c3),
            [c4]"+v" (v_c4),
            [c5]"+v" (v_c5),
            [c6]"+v" (v_c6),
            [c7]"+v" (v_c7),
            [c8]"+v" (v_c8),
            [c9]"+v" (v_c9),
            [c10]"+v"(v_c10),
            [c11]"+v"(v_c11),
            [c12]"+v"(v_c12),
            [c13]"+v"(v_c13),
            [c14]"+v"(v_c14),
            [c15]"+v"(v_c15),
            [c16]"+v"(v_c16),
            [c17]"+v"(v_c17),
            [c18]"+v"(v_c18),
            [c19]"+v"(v_c19),
            [c20]"+v"(v_c20),
            [c21]"+v"(v_c21),
            [c22]"+v"(v_c22),
            [c23]"+v"(v_c23),
            [c24]"+v"(v_c24),
            [c25]"+v"(v_c25),
            [c26]"+v"(v_c26),
            [c27]"+v"(v_c27),
            [c28]"+v"(v_c28),
            [c29]"+v"(v_c29),
            [c30]"+v"(v_c30),
            [c31]"+v"(v_c31)
            :
            [sld_a_base]"n"(0),
            [shfl_base]"n"(0),
            [v_sld_y_os]"v"(sld_y_os),
            [v_sfl_sld]"v"(sfl_sld),
            [v_sfl_sst]"v"(sfl_sst),
            [s_res_o0]"s"(res_o[0]),
            [s_res_o1]"s"(res_o[1]),
            //[s_res_o2]"s"(res_o[2]),
            //[s_res_o3]"s"(res_o[3]),
            [s_res_b0]"s"(res_b[0]),
            [s_res_b1]"s"(res_b[1]),
            [s_res_b2]"s"(res_b[2]),
            [s_res_b3]"s"(res_b[3]),
            [v_os_o0]"v"(static_cast<index_t>(cached_coords_o[number<0>{}] * sizeof(ODataType))),
            [v_os_o1]"v"(static_cast<index_t>(cached_coords_o[number<1>{}] * sizeof(ODataType))),
            [v_os_o2]"v"(static_cast<index_t>(cached_coords_o[number<2>{}] * sizeof(ODataType))),
            [v_os_o3]"v"(static_cast<index_t>(cached_coords_o[number<3>{}] * sizeof(ODataType))),
            [v_os_o4]"v"(static_cast<index_t>(cached_coords_o[number<4>{}] * sizeof(ODataType))),
            [v_os_o5]"v"(static_cast<index_t>(cached_coords_o[number<5>{}] * sizeof(ODataType))),
            [v_os_o6]"v"(static_cast<index_t>(cached_coords_o[number<6>{}] * sizeof(ODataType))),
            [v_os_o7]"v"(static_cast<index_t>(cached_coords_o[number<7>{}] * sizeof(ODataType))),
            [v_os_b0]"v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
            [v_os_b1]"v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
            [v_os_b2]"v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
            [v_os_b3]"v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
            [v_os_b4]"v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
            [v_os_b5]"v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
            [v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
            [v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),

            [s_tile_os_o]"s"(tile_stride_o_bytes),
            [s_tile_os_b]"s"(tile_stride_b_bytes),
            [scale_0]"v"(s0),
            [scale_1]"v"(s1),
            [v_nan_lo]"v"(nan_lo),
            [v_nan_hi]"v"(nan_hi),
            [s_execflag_0]"s"(o_flags[number<0>{}]),
            [s_execflag_1]"s"(o_flags[number<1>{}]),
            [s_execflag_2]"s"(o_flags[number<2>{}]),
            [s_execflag_3]"s"(o_flags[number<3>{}]),
            [s_execflag_4]"s"(o_flags[number<4>{}]),
            [s_execflag_5]"s"(o_flags[number<5>{}]),
            [s_execflag_6]"s"(o_flags[number<6>{}]),
            [s_execflag_7]"s"(o_flags[number<7>{}])
            :
            "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
            "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
            "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
            "a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
            "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47", "a48", "a49",
            "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59",
            "a60", "a61", "a62", "a63", "a64", "a65", "a66", "a67", "a68", "a69",
            "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
            "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89",
            "a90", "a91", "a92", "a93", "a94", "a95", "a96", "a97", "a98", "a99",
            "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107",
            "a108", "a109", "a110", "a111", "a112", "a113", "a114", "a115",
            "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123",
            "a124", "a125", "a126", "a127", "a128", "a129", "a130", "a131",
            "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139",
            "a140", "a141", "a142", "a143", "a144", "a145", "a146", "a147",
            "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155",
            "a156", "a157", "a158", "a159", "a160", "a161", "a162", "a163",
            "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171",
            "a172", "a173", "a174", "a175", "a176", "a177", "a178", "a179",
            "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187",
            "a188", "a189", "a190", "a191", "a192", "a193", "a194", "a195",
            "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203",
            "a204", "a205", "a206", "a207", "a208", "a209", "a210", "a211",
            "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219",
            "a220", "a221", "a222", "a223", "a224", "a225", "a226", "a227",
            "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235",
            "a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
            "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
            "a252", "a253", "a254", "a255",
            "s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
            "s36", "s37",
            "v50", "v54", "v55",
            "v64","v65","v66","v67","v68","v69","v70","v71",
            "v72","v73","v74","v75","v76","v77","v78","v79",
            "v80","v81","v82","v83","v84","v85","v86","v87",
            "v88","v89","v90","v91","v92","v93","v94","v95",
            "v128", "v129", "v130", "v131",
            "v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
            "v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
            "v148", "v149", "v150", "v151", "v152", "v153", "v154", "v155",
            "v156", "v157", "v158", "v159", "v160", "v161", "v162", "v163",
            "v164", "v165", "v166", "v167", "v168", "v169", "v170", "v171",
            "v172", "v173", "v174", "v175", "v176", "v177", "v178", "v179",
            "v180", "v181", "v182", "v183", "v184", "v185", "v186", "v187",
            "v188", "v189", "v190", "v191", "v192", "v193", "v194", "v195",
            "v196", "v197", "v198", "v199", "v200", "v201", "v202", "v203",
            "v204", "v205", "v206", "v207", "v208", "v209", "v210", "v211",
            "v212", "v213", "v214", "v215", "v216", "v217", "v218", "v219",
            "v220", "v221", "v222", "v223", "v224", "v225", "v226", "v227",
            "v228", "v229", "v230", "v231", "v232", "v233", "v234", "v235",
            "v236", "v237", "v238", "v239", "v240", "v241", "v242", "v243",
            "v244", "v245", "v246", "v247", "v248", "v249", "v250", "v251",
            "v252", "v253", "v254", "v255"
        );
#pragma clang diagnostic pop
        // clang-format on
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
10
include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp
Normal file
10
include/ck_tile/ops/flatmm/block/flatmm_uk_config.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// Flavor tokens for the flatmm micro-kernel asm fragments under uk/.
// A kernel wrapper #defines CK_TILE_FLATMM_UK_MFMA to one of these values
// immediately before #including a uk/*.inc file; the fragment then selects the
// matching mfma / convert / atomic-add instruction sequences via #if/#elif.
#define CK_TILE_FLATMM_UK_MFMA_FP16 0
#define CK_TILE_FLATMM_UK_MFMA_BF16 1
#define CK_TILE_FLATMM_UK_MFMA_INT8 2
#define CK_TILE_FLATMM_UK_MFMA_FP8 3
#define CK_TILE_FLATMM_UK_MFMA_BF8 4
|
||||
1
include/ck_tile/ops/flatmm/block/uk/README.md
Normal file
1
include/ck_tile/ops/flatmm/block/uk/README.md
Normal file
@@ -0,0 +1 @@
|
||||
The files in this folder are raw inline-assembly fragments and must not be included directly; they are only meant to be textually included from the flatmm kernel sources, with `CK_TILE_FLATMM_UK_MFMA` defined beforehand.
|
||||
@@ -0,0 +1,613 @@
|
||||
// Default to the bf16 flavor when the including kernel did not pick one.
#ifndef CK_TILE_FLATMM_UK_MFMA
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#endif

#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
# define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"

// Pack two fp32 values (x0_, x1_) into one dword of two bf16 values (y_).
// Per value: v_cmp_u_f32 sets s[36:37] when the input is NaN (unordered with
// itself); v_add3_u32 adds %[v_nan_lo] (0x7fff) plus 1 as a rounding bias;
// v_cndmask_b32 substitutes %[v_nan_hi] for NaN inputs.  The final v_perm_b32
// (selector preloaded in s52) combines the upper halves of v54/v55 into y_.
// Temps: v50, v54, v55; clobbers s[36:37].
# define _UK_PK_CVT_(x0_, x1_, y_) \
            " v_cmp_u_f32 s[36:37], " x0_ ", " x0_ " \n" \
            " v_add3_u32 v50, " x0_ ", %[v_nan_lo], 1 \n" \
            " v_cndmask_b32 v54, v50, %[v_nan_hi], s[36:37] \n" \
            " v_cmp_u_f32 s[36:37], " x1_ ", " x1_ " \n" \
            " v_add3_u32 v50, " x1_ ", %[v_nan_lo], 1 \n" \
            " v_cndmask_b32 v55, v50, %[v_nan_hi], s[36:37] \n" \
            " v_perm_b32 " y_ ", v55, v54, s52 \n"

// Output is accumulated with packed bf16 atomic adds.
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_bf16"

#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"

// fp16 flavor: hardware v_cvt_f16_f32 handles rounding, so no manual
// bias/NaN handling is needed; v_pack_b32_f16 merges the two halves into y_.
// Temps: v54, v55.
# define _UK_PK_CVT_(x0_, x1_, y_) \
            " v_cvt_f16_f32 v54, " x0_ " \n" \
            " v_cvt_f16_f32 v55, " x1_ " \n" \
            " v_pack_b32_f16 " y_ ", v54, v55 \n"

// Output is accumulated with packed fp16 atomic adds.
# define _UK_ATOMIC_ADD_ "global_atomic_pk_add_f16"

#endif
|
||||
|
||||
|
||||
";-------------------------------------------------------------\n"
|
||||
" s_mov_b32 s52, 0x07060302 ; v_perm\n"
|
||||
" s_mov_b64 s[38:39], exec ; save current exec\n"
|
||||
" s_mov_b32 s8, %[s_res_o0] \n"
|
||||
" s_mov_b32 s9, %[s_res_o1] \n"
|
||||
" s_mov_b32 s12, %[s_res_b0] \n"
|
||||
" s_mov_b32 s13, %[s_res_b1] \n"
|
||||
" s_mov_b32 s14, %[s_res_b2] \n"
|
||||
" s_mov_b32 s15, %[s_res_b3] \n"
|
||||
" ds_read_b64 v[128:129], %[v_sld_y_os] offset:0 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[130:131], %[v_sld_y_os] offset:128 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[132:133], %[v_sld_y_os] offset:1024 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[134:135], %[v_sld_y_os] offset:1152 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[136:137], %[v_sld_y_os] offset:2048 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[138:139], %[v_sld_y_os] offset:2176 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[140:141], %[v_sld_y_os] offset:3072 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[142:143], %[v_sld_y_os] offset:3200 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[144:145], %[v_sld_y_os] offset:4096 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[146:147], %[v_sld_y_os] offset:4224 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[148:149], %[v_sld_y_os] offset:5120 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[150:151], %[v_sld_y_os] offset:5248 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[152:153], %[v_sld_y_os] offset:6144 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[154:155], %[v_sld_y_os] offset:6272 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[156:157], %[v_sld_y_os] offset:7168 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[158:159], %[v_sld_y_os] offset:7296 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[160:161], %[v_sld_y_os] offset:8192 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[162:163], %[v_sld_y_os] offset:8320 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[164:165], %[v_sld_y_os] offset:9216 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[166:167], %[v_sld_y_os] offset:9344 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[168:169], %[v_sld_y_os] offset:10240 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[170:171], %[v_sld_y_os] offset:10368 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[172:173], %[v_sld_y_os] offset:11264 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[174:175], %[v_sld_y_os] offset:11392 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[176:177], %[v_sld_y_os] offset:12288 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[178:179], %[v_sld_y_os] offset:12416 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[180:181], %[v_sld_y_os] offset:13312 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[182:183], %[v_sld_y_os] offset:13440 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[184:185], %[v_sld_y_os] offset:14336 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[186:187], %[v_sld_y_os] offset:14464 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[188:189], %[v_sld_y_os] offset:15360 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[190:191], %[v_sld_y_os] offset:15488 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[192:193], %[v_sld_y_os] offset:16384 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[194:195], %[v_sld_y_os] offset:16512 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[196:197], %[v_sld_y_os] offset:17408 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[198:199], %[v_sld_y_os] offset:17536 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[200:201], %[v_sld_y_os] offset:18432 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[202:203], %[v_sld_y_os] offset:18560 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[204:205], %[v_sld_y_os] offset:19456 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[206:207], %[v_sld_y_os] offset:19584 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[208:209], %[v_sld_y_os] offset:20480 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[210:211], %[v_sld_y_os] offset:20608 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[212:213], %[v_sld_y_os] offset:21504 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[214:215], %[v_sld_y_os] offset:21632 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[216:217], %[v_sld_y_os] offset:22528 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[218:219], %[v_sld_y_os] offset:22656 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[220:221], %[v_sld_y_os] offset:23552 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[222:223], %[v_sld_y_os] offset:23680 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[224:225], %[v_sld_y_os] offset:24576 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[226:227], %[v_sld_y_os] offset:24704 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[228:229], %[v_sld_y_os] offset:25600 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[230:231], %[v_sld_y_os] offset:25728 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[232:233], %[v_sld_y_os] offset:26624 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[234:235], %[v_sld_y_os] offset:26752 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[236:237], %[v_sld_y_os] offset:27648 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[238:239], %[v_sld_y_os] offset:27776 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[240:241], %[v_sld_y_os] offset:28672 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[242:243], %[v_sld_y_os] offset:28800 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[244:245], %[v_sld_y_os] offset:29696 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[246:247], %[v_sld_y_os] offset:29824 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[248:249], %[v_sld_y_os] offset:30720 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[250:251], %[v_sld_y_os] offset:30848 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[252:253], %[v_sld_y_os] offset:31744 + %[sld_a_base] \n"
|
||||
" ds_read_b64 v[254:255], %[v_sld_y_os] offset:31872 + %[sld_a_base] \n"
|
||||
" s_waitcnt 0 \n"
|
||||
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
|
||||
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
|
||||
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
|
||||
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
|
||||
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
|
||||
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
|
||||
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
|
||||
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
|
||||
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
|
||||
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
|
||||
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
|
||||
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
|
||||
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
|
||||
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
|
||||
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
|
||||
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
|
||||
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
|
||||
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
|
||||
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
|
||||
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
|
||||
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
|
||||
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
|
||||
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
|
||||
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
|
||||
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
|
||||
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
|
||||
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
|
||||
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
|
||||
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
|
||||
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
|
||||
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
|
||||
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
|
||||
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
|
||||
" s_add_u32 s12, s86, s12 \n"
|
||||
" s_addc_u32 s13, 0, s13 \n"
|
||||
" s_waitcnt 0 \n"
|
||||
"L_start%=: \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
" s_barrier \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[0:1], v[128:129], 0 \n"
|
||||
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[2:3], v[130:131], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[4:5], v[132:133], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[6:7], v[134:135], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[8:9], v[136:137], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[10:11], v[138:139], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[12:13], v[140:141], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[14:15], v[142:143], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[0:1], v[192:193], 0 \n"
|
||||
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[2:3], v[194:195], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[4:5], v[196:197], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[6:7], v[198:199], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[8:9], v[200:201], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[10:11], v[202:203], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[12:13], v[204:205], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[14:15], v[206:207], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[16:17], v[128:129], 0 \n"
|
||||
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[18:19], v[130:131], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[20:21], v[132:133], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[22:23], v[134:135], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[24:25], v[136:137], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[26:27], v[138:139], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[28:29], v[140:141], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[30:31], v[142:143], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[16:17], v[192:193], 0 \n"
|
||||
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[18:19], v[194:195], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[20:21], v[196:197], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[22:23], v[198:199], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[24:25], v[200:201], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[26:27], v[202:203], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[28:29], v[204:205], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[30:31], v[206:207], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[32:33], v[144:145], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[34:35], v[146:147], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[36:37], v[148:149], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[38:39], v[150:151], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[40:41], v[152:153], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[42:43], v[154:155], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[44:45], v[156:157], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[46:47], v[158:159], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[32:33], v[208:209], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[34:35], v[210:211], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[36:37], v[212:213], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[38:39], v[214:215], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[40:41], v[216:217], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[42:43], v[218:219], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[44:45], v[220:221], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[46:47], v[222:223], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[48:49], v[144:145], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[50:51], v[146:147], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[52:53], v[148:149], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[54:55], v[150:151], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[56:57], v[152:153], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[58:59], v[154:155], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[60:61], v[156:157], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[62:63], v[158:159], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[48:49], v[208:209], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[50:51], v[210:211], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[52:53], v[212:213], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[54:55], v[214:215], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[56:57], v[216:217], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[58:59], v[218:219], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[60:61], v[220:221], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[62:63], v[222:223], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[64:65], v[160:161], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[66:67], v[162:163], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[68:69], v[164:165], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[70:71], v[166:167], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[72:73], v[168:169], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[74:75], v[170:171], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[76:77], v[172:173], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[78:79], v[174:175], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[64:65], v[224:225], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[66:67], v[226:227], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[68:69], v[228:229], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[70:71], v[230:231], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[72:73], v[232:233], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[74:75], v[234:235], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[76:77], v[236:237], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[78:79], v[238:239], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[80:81], v[160:161], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[82:83], v[162:163], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[84:85], v[164:165], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[86:87], v[166:167], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[88:89], v[168:169], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[90:91], v[170:171], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[92:93], v[172:173], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[94:95], v[174:175], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[80:81], v[224:225], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[82:83], v[226:227], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[84:85], v[228:229], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[86:87], v[230:231], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[88:89], v[232:233], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[90:91], v[234:235], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[92:93], v[236:237], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[94:95], v[238:239], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[96:97], v[176:177], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[98:99], v[178:179], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[100:101], v[180:181], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[102:103], v[182:183], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[104:105], v[184:185], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[106:107], v[186:187], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[108:109], v[188:189], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c0], %[c1], %[c2], %[c3]], acc[110:111], v[190:191], [%[c0], %[c1], %[c2], %[c3]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[96:97], v[240:241], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[98:99], v[242:243], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[100:101], v[244:245], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[102:103], v[246:247], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[104:105], v[248:249], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[106:107], v[250:251], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[108:109], v[252:253], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c4], %[c5], %[c6], %[c7]], acc[110:111], v[254:255], [%[c4], %[c5], %[c6], %[c7]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[112:113], v[176:177], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[114:115], v[178:179], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[116:117], v[180:181], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[118:119], v[182:183], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[120:121], v[184:185], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[122:123], v[186:187], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[124:125], v[188:189], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c8], %[c9], %[c10], %[c11]], acc[126:127], v[190:191], [%[c8], %[c9], %[c10], %[c11]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[112:113], v[240:241], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[114:115], v[242:243], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[116:117], v[244:245], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[118:119], v[246:247], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[120:121], v[248:249], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[122:123], v[250:251], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[124:125], v[252:253], [%[c12], %[c13], %[c14], %[c15]] \n"
|
||||
_UK_MFMA_ " [%[c12], %[c13], %[c14], %[c15]], acc[126:127], v[254:255], [%[c12], %[c13], %[c14], %[c15]]\n"
|
||||
" v_mul_f32 %[c0], %[scale_0], %[c0] \n"
|
||||
" v_mul_f32 %[c1], %[scale_0], %[c1] \n"
|
||||
" v_mul_f32 %[c2], %[scale_0], %[c2] \n"
|
||||
" v_mul_f32 %[c3], %[scale_0], %[c3] \n"
|
||||
" v_mul_f32 %[c4], %[scale_1], %[c4] \n"
|
||||
" v_mul_f32 %[c5], %[scale_1], %[c5] \n"
|
||||
" v_mul_f32 %[c6], %[scale_1], %[c6] \n"
|
||||
" v_mul_f32 %[c7], %[scale_1], %[c7] \n"
|
||||
" v_mul_f32 %[c8], %[scale_0], %[c8] \n"
|
||||
" v_mul_f32 %[c9], %[scale_0], %[c9] \n"
|
||||
" v_mul_f32 %[c10], %[scale_0], %[c10] \n"
|
||||
" v_mul_f32 %[c11], %[scale_0], %[c11] \n"
|
||||
" v_mul_f32 %[c12], %[scale_1], %[c12] \n"
|
||||
" v_mul_f32 %[c13], %[scale_1], %[c13] \n"
|
||||
" v_mul_f32 %[c14], %[scale_1], %[c14] \n"
|
||||
" v_mul_f32 %[c15], %[scale_1], %[c15] \n"
|
||||
_UK_PK_CVT_("%[c0]", "%[c1]", "%[c0]")
|
||||
_UK_PK_CVT_("%[c2]", "%[c3]", "%[c1]")
|
||||
_UK_PK_CVT_("%[c4]", "%[c5]", "%[c2]")
|
||||
_UK_PK_CVT_("%[c6]", "%[c7]", "%[c3]")
|
||||
_UK_PK_CVT_("%[c8]", "%[c9]", "%[c4]")
|
||||
_UK_PK_CVT_("%[c10]", "%[c11]", "%[c5]")
|
||||
_UK_PK_CVT_("%[c12]", "%[c13]", "%[c6]")
|
||||
_UK_PK_CVT_("%[c14]", "%[c15]", "%[c7]")
|
||||
" ;------------------------------ \n"
|
||||
" ds_write_b64 %[v_sfl_sst], [%[c0],%[c1]] offset:0 + %[shfl_base] \n"
|
||||
" ds_write_b64 %[v_sfl_sst], [%[c2],%[c3]] offset:4352 + %[shfl_base] \n"
|
||||
" ds_write_b64 %[v_sfl_sst], [%[c4],%[c5]] offset:2176 + %[shfl_base] \n"
|
||||
" ds_write_b64 %[v_sfl_sst], [%[c6],%[c7]] offset:6528 + %[shfl_base] \n"
|
||||
" s_waitcnt lgkmcnt(0) \n"
|
||||
" s_barrier \n"
|
||||
" ds_read_b32 %[c0], %[v_sfl_sld] offset:0 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c1], %[v_sfl_sld] offset:32 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c2], %[v_sfl_sld] offset:64 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c3], %[v_sfl_sld] offset:96 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c4], %[v_sfl_sld] offset:4352 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c5], %[v_sfl_sld] offset:4384 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c6], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c7], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
|
||||
" s_waitcnt lgkmcnt(0) \n"
|
||||
" s_mov_b64 exec, %[s_execflag_0] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o0], %[c0], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_1] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o1], %[c1], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_2] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o2], %[c2], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_3] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o3], %[c3], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_4] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o4], %[c4], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_5] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o5], %[c5], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_6] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o6], %[c6], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_7] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o7], %[c7], s[8:9] \n"
|
||||
" s_mov_b64 exec, s[38:39] \n"
|
||||
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
|
||||
" s_cbranch_scc0 L_end%= \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
|
||||
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
|
||||
" s_add_u32 s12, s86, s12 \n"
|
||||
" s_addc_u32 s13, 0, s13 \n"
|
||||
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
|
||||
" s_addc_u32 s9, 0, s9 \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
" s_barrier \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[128:129], v[128:129], 0 \n"
|
||||
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[130:131], v[130:131], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[132:133], v[132:133], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[134:135], v[134:135], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[136:137], v[136:137], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[138:139], v[138:139], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[140:141], v[140:141], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[142:143], v[142:143], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[128:129], v[192:193], 0 \n"
|
||||
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[130:131], v[194:195], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[132:133], v[196:197], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[134:135], v[198:199], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[136:137], v[200:201], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[138:139], v[202:203], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[140:141], v[204:205], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[142:143], v[206:207], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[144:145], v[128:129], 0 \n"
|
||||
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[146:147], v[130:131], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[148:149], v[132:133], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[150:151], v[134:135], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[152:153], v[136:137], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[154:155], v[138:139], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[156:157], v[140:141], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[158:159], v[142:143], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[144:145], v[192:193], 0 \n"
|
||||
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[146:147], v[194:195], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[148:149], v[196:197], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[150:151], v[198:199], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[152:153], v[200:201], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[154:155], v[202:203], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[156:157], v[204:205], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[158:159], v[206:207], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[160:161], v[144:145], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[162:163], v[146:147], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[164:165], v[148:149], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[166:167], v[150:151], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[168:169], v[152:153], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[170:171], v[154:155], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[172:173], v[156:157], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[174:175], v[158:159], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[160:161], v[208:209], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[162:163], v[210:211], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[164:165], v[212:213], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[166:167], v[214:215], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[168:169], v[216:217], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[170:171], v[218:219], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[172:173], v[220:221], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[174:175], v[222:223], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[176:177], v[144:145], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[178:179], v[146:147], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[180:181], v[148:149], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[182:183], v[150:151], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[184:185], v[152:153], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[186:187], v[154:155], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[188:189], v[156:157], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[190:191], v[158:159], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[176:177], v[208:209], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[178:179], v[210:211], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[180:181], v[212:213], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[182:183], v[214:215], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[184:185], v[216:217], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[186:187], v[218:219], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[188:189], v[220:221], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[190:191], v[222:223], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[192:193], v[160:161], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[194:195], v[162:163], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[196:197], v[164:165], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[198:199], v[166:167], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[200:201], v[168:169], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[202:203], v[170:171], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[204:205], v[172:173], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[206:207], v[174:175], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[192:193], v[224:225], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[194:195], v[226:227], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[196:197], v[228:229], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[198:199], v[230:231], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[200:201], v[232:233], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[202:203], v[234:235], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[204:205], v[236:237], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[206:207], v[238:239], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[208:209], v[160:161], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[210:211], v[162:163], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[212:213], v[164:165], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[214:215], v[166:167], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[216:217], v[168:169], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[218:219], v[170:171], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[220:221], v[172:173], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[222:223], v[174:175], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[208:209], v[224:225], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[210:211], v[226:227], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[212:213], v[228:229], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[214:215], v[230:231], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[216:217], v[232:233], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[218:219], v[234:235], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[220:221], v[236:237], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[222:223], v[238:239], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[224:225], v[176:177], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[226:227], v[178:179], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[228:229], v[180:181], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[230:231], v[182:183], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[232:233], v[184:185], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[234:235], v[186:187], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[236:237], v[188:189], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c16],%[c17],%[c18],%[c19]], acc[238:239], v[190:191], [%[c16],%[c17],%[c18],%[c19]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[224:225], v[240:241], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[226:227], v[242:243], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[228:229], v[244:245], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[230:231], v[246:247], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[232:233], v[248:249], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[234:235], v[250:251], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[236:237], v[252:253], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c20],%[c21],%[c22],%[c23]], acc[238:239], v[254:255], [%[c20],%[c21],%[c22],%[c23]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[240:241], v[176:177], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[12:15], 0 offen \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[242:243], v[178:179], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[244:245], v[180:181], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[246:247], v[182:183], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[248:249], v[184:185], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[12:15], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[250:251], v[186:187], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[252:253], v[188:189], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c24],%[c25],%[c26],%[c27]], acc[254:255], v[190:191], [%[c24],%[c25],%[c26],%[c27]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[240:241], v[240:241], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[242:243], v[242:243], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[244:245], v[244:245], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[246:247], v[246:247], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[248:249], v[248:249], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[250:251], v[250:251], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[252:253], v[252:253], [%[c28],%[c29],%[c30],%[c31]] \n"
|
||||
_UK_MFMA_ " [%[c28],%[c29],%[c30],%[c31]], acc[254:255], v[254:255], [%[c28],%[c29],%[c30],%[c31]]\n"
|
||||
" v_mul_f32 %[c16], %[scale_0], %[c16] \n"
|
||||
" v_mul_f32 %[c17], %[scale_0], %[c17] \n"
|
||||
" v_mul_f32 %[c18], %[scale_0], %[c18] \n"
|
||||
" v_mul_f32 %[c19], %[scale_0], %[c19] \n"
|
||||
" v_mul_f32 %[c20], %[scale_1], %[c20] \n"
|
||||
" v_mul_f32 %[c21], %[scale_1], %[c21] \n"
|
||||
" v_mul_f32 %[c22], %[scale_1], %[c22] \n"
|
||||
" v_mul_f32 %[c23], %[scale_1], %[c23] \n"
|
||||
" v_mul_f32 %[c24], %[scale_0], %[c24] \n"
|
||||
" v_mul_f32 %[c25], %[scale_0], %[c25] \n"
|
||||
" v_mul_f32 %[c26], %[scale_0], %[c26] \n"
|
||||
" v_mul_f32 %[c27], %[scale_0], %[c27] \n"
|
||||
" v_mul_f32 %[c28], %[scale_1], %[c28] \n"
|
||||
" v_mul_f32 %[c29], %[scale_1], %[c29] \n"
|
||||
" v_mul_f32 %[c30], %[scale_1], %[c30] \n"
|
||||
" v_mul_f32 %[c31], %[scale_1], %[c31] \n"
|
||||
|
||||
_UK_PK_CVT_("%[c16]", "%[c17]", "%[c16]")
|
||||
_UK_PK_CVT_("%[c18]", "%[c19]", "%[c17]")
|
||||
_UK_PK_CVT_("%[c20]", "%[c21]", "%[c18]")
|
||||
_UK_PK_CVT_("%[c22]", "%[c23]", "%[c19]")
|
||||
_UK_PK_CVT_("%[c24]", "%[c25]", "%[c20]")
|
||||
_UK_PK_CVT_("%[c26]", "%[c27]", "%[c21]")
|
||||
_UK_PK_CVT_("%[c28]", "%[c29]", "%[c22]")
|
||||
_UK_PK_CVT_("%[c30]", "%[c31]", "%[c23]")
|
||||
|
||||
" ;------------------------------ \n"
|
||||
" ds_write_b64 %[v_sfl_sst], [%[c16],%[c17]] offset:0 + %[shfl_base] \n"
|
||||
" ds_write_b64 %[v_sfl_sst], [%[c18],%[c19]] offset:4352 + %[shfl_base] \n"
|
||||
" ds_write_b64 %[v_sfl_sst], [%[c20],%[c21]] offset:2176 + %[shfl_base] \n"
|
||||
" ds_write_b64 %[v_sfl_sst], [%[c22],%[c23]] offset:6528 + %[shfl_base] \n"
|
||||
" s_waitcnt lgkmcnt(0) \n"
|
||||
" s_barrier \n"
|
||||
" ds_read_b32 %[c16], %[v_sfl_sld] offset:0 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c17], %[v_sfl_sld] offset:32 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c18], %[v_sfl_sld] offset:64 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c19], %[v_sfl_sld] offset:96 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c20], %[v_sfl_sld] offset:4352 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c21], %[v_sfl_sld] offset:4384 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c22], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
|
||||
" ds_read_b32 %[c23], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
|
||||
" s_waitcnt lgkmcnt(0) \n"
|
||||
" s_mov_b64 exec, %[s_execflag_0] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o0], %[c16], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_1] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o1], %[c17], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_2] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o2], %[c18], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_3] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o3], %[c19], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_4] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o4], %[c20], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_5] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o5], %[c21], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_6] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o6], %[c22], s[8:9] \n"
|
||||
" s_mov_b64 exec, %[s_execflag_7] \n"
|
||||
_UK_ATOMIC_ADD_ " %[v_os_o7], %[c23], s[8:9] \n"
|
||||
" s_mov_b64 exec, s[38:39] \n"
|
||||
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
|
||||
" s_cbranch_scc0 L_end%= \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
|
||||
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
|
||||
" s_add_u32 s12, s86, s12 \n"
|
||||
" s_addc_u32 s13, 0, s13 \n"
|
||||
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
|
||||
" s_addc_u32 s9, 0, s9 \n"
|
||||
" s_branch L_start%= \n"
|
||||
"L_end%=: \n"
|
||||
|
||||
#undef _UK_MFMA_
|
||||
#undef _UK_PK_CVT_
|
||||
#undef _UK_ATOMIC_ADD_
|
||||
@@ -0,0 +1,516 @@
|
||||
#ifndef CK_TILE_FLATMM_UK_MFMA
|
||||
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
|
||||
#endif
|
||||
|
||||
#if CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_BF16
|
||||
#define _UK_MFMA_ "v_mfma_f32_16x16x16_bf16"
|
||||
#elif CK_TILE_FLATMM_UK_MFMA == CK_TILE_FLATMM_UK_MFMA_FP16
|
||||
#define _UK_MFMA_ "v_mfma_f32_16x16x16_f16"
|
||||
#endif
|
||||
|
||||
"s_mov_b32 s16, %[s_res_a0] \n"
|
||||
"s_mov_b32 s17, %[s_res_a1] \n"
|
||||
"s_mov_b32 s18, %[s_res_a2] \n"
|
||||
"s_mov_b32 s19, %[s_res_a3] \n"
|
||||
"s_mov_b32 s20, %[s_res_b0] \n"
|
||||
"s_mov_b32 s21, %[s_res_b1] \n"
|
||||
"s_mov_b32 s22, %[s_res_b2] \n"
|
||||
"s_mov_b32 s23, %[s_res_b3] \n"
|
||||
// "s_nop 4\n"
|
||||
"; -- prefetch A0\n"
|
||||
"s_add_u32 m0, 0, %[s_m0_init] \n"
|
||||
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
|
||||
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond \n"
|
||||
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
|
||||
"s_add_u32 s16, s86, s16 ; move a with cond \n"
|
||||
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
|
||||
"; -- prefetch A1\n"
|
||||
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
|
||||
"s_add_u32 m0, 0, %[s_m0_init] \n"
|
||||
"s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
|
||||
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
|
||||
"s_add_u32 s16, s86, s16 ; move a with cond \n"
|
||||
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
|
||||
"; -- prefetch B0\n"
|
||||
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
|
||||
"buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
|
||||
"buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
|
||||
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
|
||||
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
|
||||
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
|
||||
"s_cselect_b32 s86, %[s_tile_os_b], 0 ; move b with cond \n"
|
||||
"s_add_u32 s20, s86, s20 ; move b with cond \n"
|
||||
"s_addc_u32 s21, 0, s21 ; move b with cond \n"
|
||||
"s_waitcnt vmcnt(40) \n"
|
||||
"s_barrier \n"
|
||||
"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 K stride
|
||||
"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n"
|
||||
"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n"
|
||||
"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n"
|
||||
"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n"
|
||||
"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n"
|
||||
"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n"
|
||||
"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n"
|
||||
"L_start%=: \n"
|
||||
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
|
||||
" s_barrier \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[0:1], v[64:65], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[2:3], v[66:67], %[v_acc_0] \n"
|
||||
" buffer_load_dwordx4 acc[128:131], %[v_os_b0], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[4:5], v[68:69], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[6:7], v[70:71], %[v_acc_0] \n"
|
||||
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[8:9], v[72:73], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[10:11], v[74:75], %[v_acc_0] \n"
|
||||
" buffer_load_dwordx4 acc[132:135], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[12:13], v[76:77], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[14:15], v[78:79], %[v_acc_0] \n"
|
||||
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[0:1], v[80:81], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[2:3], v[82:83], %[v_acc_1] \n"
|
||||
" buffer_load_dwordx4 acc[136:139], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[4:5], v[84:85], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[6:7], v[86:87], %[v_acc_1] \n"
|
||||
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[8:9], v[88:89], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[10:11], v[90:91], %[v_acc_1] \n"
|
||||
" buffer_load_dwordx4 acc[140:143], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[12:13], v[92:93], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[14:15], v[94:95], %[v_acc_1] \n"
|
||||
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[16:17], v[64:65], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[18:19], v[66:67], %[v_acc_2] \n"
|
||||
" buffer_load_dwordx4 acc[144:147], %[v_os_b1], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[20:21], v[68:69], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[22:23], v[70:71], %[v_acc_2] \n"
|
||||
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[24:25], v[72:73], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[26:27], v[74:75], %[v_acc_2] \n"
|
||||
" buffer_load_dwordx4 acc[148:151], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[28:29], v[76:77], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[30:31], v[78:79], %[v_acc_2] \n"
|
||||
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[16:17], v[80:81], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[18:19], v[82:83], %[v_acc_3] \n"
|
||||
" buffer_load_dwordx4 acc[152:155], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[20:21], v[84:85], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[22:23], v[86:87], %[v_acc_3] \n"
|
||||
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[24:25], v[88:89], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[26:27], v[90:91], %[v_acc_3] \n"
|
||||
" buffer_load_dwordx4 acc[156:159], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[28:29], v[92:93], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[30:31], v[94:95], %[v_acc_3] \n"
|
||||
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[32:33], v[64:65], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[34:35], v[66:67], %[v_acc_4] \n"
|
||||
" buffer_load_dwordx4 acc[160:163], %[v_os_b2], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[36:37], v[68:69], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[38:39], v[70:71], %[v_acc_4] \n"
|
||||
" ds_read_b128 v[96:99], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[40:41], v[72:73], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[42:43], v[74:75], %[v_acc_4] \n"
|
||||
" buffer_load_dwordx4 acc[164:167], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[44:45], v[76:77], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[46:47], v[78:79], %[v_acc_4] \n"
|
||||
" ds_read_b128 v[100:103], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[32:33], v[80:81], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[34:35], v[82:83], %[v_acc_5] \n"
|
||||
" buffer_load_dwordx4 acc[168:171], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[36:37], v[84:85], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[38:39], v[86:87], %[v_acc_5] \n"
|
||||
" ds_read_b128 v[104:107], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[40:41], v[88:89], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[42:43], v[90:91], %[v_acc_5] \n"
|
||||
" buffer_load_dwordx4 acc[172:175], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[44:45], v[92:93], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[46:47], v[94:95], %[v_acc_5] \n"
|
||||
" ds_read_b128 v[108:111], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[48:49], v[64:65], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[50:51], v[66:67], %[v_acc_6] \n"
|
||||
" buffer_load_dwordx4 acc[176:179], %[v_os_b3], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[52:53], v[68:69], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[54:55], v[70:71], %[v_acc_6] \n"
|
||||
" ds_read_b128 v[112:115], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[56:57], v[72:73], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[58:59], v[74:75], %[v_acc_6] \n"
|
||||
" buffer_load_dwordx4 acc[180:183], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[60:61], v[76:77], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[62:63], v[78:79], %[v_acc_6] \n"
|
||||
" ds_read_b128 v[116:119], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[48:49], v[80:81], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[50:51], v[82:83], %[v_acc_7] \n"
|
||||
" buffer_load_dwordx4 acc[184:187], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[52:53], v[84:85], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[54:55], v[86:87], %[v_acc_7] \n"
|
||||
" ds_read_b128 v[120:123], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[56:57], v[88:89], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[58:59], v[90:91], %[v_acc_7] \n"
|
||||
" buffer_load_dwordx4 acc[188:191], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[60:61], v[92:93], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[62:63], v[94:95], %[v_acc_7] \n"
|
||||
" ds_read_b128 v[124:127], %[v_os_slda], offset:1*%[smem_sz] + %[sld_os_7] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[64:65], v[64:65], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[66:67], v[66:67], %[v_acc_8] \n"
|
||||
" buffer_load_dwordx4 acc[192:195], %[v_os_b4], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[68:69], v[68:69], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[70:71], v[70:71], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[72:73], v[72:73], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[74:75], v[74:75], %[v_acc_8] \n"
|
||||
" buffer_load_dwordx4 acc[196:199], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[76:77], v[76:77], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[78:79], v[78:79], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[64:65], v[80:81], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[66:67], v[82:83], %[v_acc_9] \n"
|
||||
" buffer_load_dwordx4 acc[200:203], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[68:69], v[84:85], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[70:71], v[86:87], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[72:73], v[88:89], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[74:75], v[90:91], %[v_acc_9] \n"
|
||||
" buffer_load_dwordx4 acc[204:207], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[76:77], v[92:93], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[78:79], v[94:95], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[80:81], v[64:65], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[82:83], v[66:67], %[v_acc_10] \n"
|
||||
" buffer_load_dwordx4 acc[208:211], %[v_os_b5], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[84:85], v[68:69], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[86:87], v[70:71], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[88:89], v[72:73], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[90:91], v[74:75], %[v_acc_10] \n"
|
||||
" buffer_load_dwordx4 acc[212:215], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[92:93], v[76:77], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[94:95], v[78:79], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[80:81], v[80:81], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[82:83], v[82:83], %[v_acc_11] \n"
|
||||
" buffer_load_dwordx4 acc[216:219], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[84:85], v[84:85], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[86:87], v[86:87], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[88:89], v[88:89], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[90:91], v[90:91], %[v_acc_11] \n"
|
||||
" buffer_load_dwordx4 acc[220:223], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[92:93], v[92:93], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[94:95], v[94:95], %[v_acc_11] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[96:97], v[64:65], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[98:99], v[66:67], %[v_acc_12] \n"
|
||||
" buffer_load_dwordx4 acc[224:227], %[v_os_b6], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[100:101], v[68:69], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[102:103], v[70:71], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[104:105], v[72:73], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[106:107], v[74:75], %[v_acc_12] \n"
|
||||
" buffer_load_dwordx4 acc[228:231], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[108:109], v[76:77], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[110:111], v[78:79], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[96:97], v[80:81], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[98:99], v[82:83], %[v_acc_13] \n"
|
||||
" buffer_load_dwordx4 acc[232:235], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[100:101], v[84:85], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[102:103], v[86:87], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[104:105], v[88:89], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[106:107], v[90:91], %[v_acc_13] \n"
|
||||
" buffer_load_dwordx4 acc[236:239], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[108:109], v[92:93], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[110:111], v[94:95], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[112:113], v[64:65], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[114:115], v[66:67], %[v_acc_14] \n"
|
||||
" buffer_load_dwordx4 acc[240:243], %[v_os_b7], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[116:117], v[68:69], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[118:119], v[70:71], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[120:121], v[72:73], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[122:123], v[74:75], %[v_acc_14] \n"
|
||||
" buffer_load_dwordx4 acc[244:247], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[124:125], v[76:77], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[126:127], v[78:79], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[112:113], v[80:81], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[114:115], v[82:83], %[v_acc_15] \n"
|
||||
" buffer_load_dwordx4 acc[248:251], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[116:117], v[84:85], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15] \n"
|
||||
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072\n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15] \n"
|
||||
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
|
||||
" s_cbranch_scc0 L_end%= \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
|
||||
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
|
||||
" s_add_u32 s16, s86, s16 \n"
|
||||
" s_addc_u32 s17, 0, s17 \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
|
||||
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
|
||||
" s_add_u32 s20, s86, s20 \n"
|
||||
" s_addc_u32 s21, 0, s21 \n"
|
||||
" ;------------------------------------------ \n"
|
||||
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
|
||||
" s_barrier \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[128:129], v[96:97], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[130:131], v[98:99], %[v_acc_0] \n"
|
||||
" buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[132:133], v[100:101], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[134:135], v[102:103], %[v_acc_0] \n"
|
||||
" buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[136:137], v[104:105], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[138:139], v[106:107], %[v_acc_0] \n"
|
||||
" buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[140:141], v[108:109], %[v_acc_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_0], acc[142:143], v[110:111], %[v_acc_0] \n"
|
||||
" buffer_load_dword %[v_os_a1], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[128:129], v[112:113], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[130:131], v[114:115], %[v_acc_1] \n"
|
||||
" buffer_load_dwordx4 acc[8:11], %[v_os_b0], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[132:133], v[116:117], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[134:135], v[118:119], %[v_acc_1] \n"
|
||||
" buffer_load_dword %[v_os_a2], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[136:137], v[120:121], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[138:139], v[122:123], %[v_acc_1] \n"
|
||||
" buffer_load_dwordx4 acc[12:15], %[v_os_b0], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[140:141], v[124:125], %[v_acc_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_1], acc[142:143], v[126:127], %[v_acc_1] \n"
|
||||
" buffer_load_dword %[v_os_a3], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[144:145], v[96:97], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[146:147], v[98:99], %[v_acc_2] \n"
|
||||
" buffer_load_dwordx4 acc[16:19], %[v_os_b1], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[148:149], v[100:101], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[150:151], v[102:103], %[v_acc_2] \n"
|
||||
" buffer_load_dword %[v_os_a4], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[152:153], v[104:105], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[154:155], v[106:107], %[v_acc_2] \n"
|
||||
" buffer_load_dwordx4 acc[20:23], %[v_os_b1], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[156:157], v[108:109], %[v_acc_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_2], acc[158:159], v[110:111], %[v_acc_2] \n"
|
||||
" buffer_load_dword %[v_os_a5], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[144:145], v[112:113], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[146:147], v[114:115], %[v_acc_3] \n"
|
||||
" buffer_load_dwordx4 acc[24:27], %[v_os_b1], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[148:149], v[116:117], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[150:151], v[118:119], %[v_acc_3] \n"
|
||||
" buffer_load_dword %[v_os_a6], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, %[s_size_per_issue], m0 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[152:153], v[120:121], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[154:155], v[122:123], %[v_acc_3] \n"
|
||||
" buffer_load_dwordx4 acc[28:31], %[v_os_b1], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[156:157], v[124:125], %[v_acc_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_3], acc[158:159], v[126:127], %[v_acc_3] \n"
|
||||
" buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
|
||||
" s_add_u32 m0, 0, %[s_m0_init] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[160:161], v[96:97], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[162:163], v[98:99], %[v_acc_4] \n"
|
||||
" buffer_load_dwordx4 acc[32:35], %[v_os_b2], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[164:165], v[100:101], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[166:167], v[102:103], %[v_acc_4] \n"
|
||||
" ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[168:169], v[104:105], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[170:171], v[106:107], %[v_acc_4] \n"
|
||||
" buffer_load_dwordx4 acc[36:39], %[v_os_b2], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[172:173], v[108:109], %[v_acc_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_4], acc[174:175], v[110:111], %[v_acc_4] \n"
|
||||
" ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[160:161], v[112:113], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[162:163], v[114:115], %[v_acc_5] \n"
|
||||
" buffer_load_dwordx4 acc[40:43], %[v_os_b2], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[164:165], v[116:117], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[166:167], v[118:119], %[v_acc_5] \n"
|
||||
" ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[168:169], v[120:121], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[170:171], v[122:123], %[v_acc_5] \n"
|
||||
" buffer_load_dwordx4 acc[44:47], %[v_os_b2], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[172:173], v[124:125], %[v_acc_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_5], acc[174:175], v[126:127], %[v_acc_5] \n"
|
||||
" ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[176:177], v[96:97], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[178:179], v[98:99], %[v_acc_6] \n"
|
||||
" buffer_load_dwordx4 acc[48:51], %[v_os_b3], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[180:181], v[100:101], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[182:183], v[102:103], %[v_acc_6] \n"
|
||||
" ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[184:185], v[104:105], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[186:187], v[106:107], %[v_acc_6] \n"
|
||||
" buffer_load_dwordx4 acc[52:55], %[v_os_b3], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[188:189], v[108:109], %[v_acc_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_6], acc[190:191], v[110:111], %[v_acc_6] \n"
|
||||
" ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[176:177], v[112:113], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[178:179], v[114:115], %[v_acc_7] \n"
|
||||
" buffer_load_dwordx4 acc[56:59], %[v_os_b3], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[180:181], v[116:117], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[182:183], v[118:119], %[v_acc_7] \n"
|
||||
" ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[184:185], v[120:121], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[186:187], v[122:123], %[v_acc_7] \n"
|
||||
" buffer_load_dwordx4 acc[60:63], %[v_os_b3], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[188:189], v[124:125], %[v_acc_7] \n"
|
||||
_UK_MFMA_ " %[v_acc_7], acc[190:191], v[126:127], %[v_acc_7] \n"
|
||||
" ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[192:193], v[96:97], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[194:195], v[98:99], %[v_acc_8] \n"
|
||||
" buffer_load_dwordx4 acc[64:67], %[v_os_b4], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[196:197], v[100:101], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[198:199], v[102:103], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[200:201], v[104:105], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[202:203], v[106:107], %[v_acc_8] \n"
|
||||
" buffer_load_dwordx4 acc[68:71], %[v_os_b4], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[204:205], v[108:109], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_8], acc[206:207], v[110:111], %[v_acc_8] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[192:193], v[112:113], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[194:195], v[114:115], %[v_acc_9] \n"
|
||||
" buffer_load_dwordx4 acc[72:75], %[v_os_b4], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[196:197], v[116:117], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[198:199], v[118:119], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[200:201], v[120:121], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[202:203], v[122:123], %[v_acc_9] \n"
|
||||
" buffer_load_dwordx4 acc[76:79], %[v_os_b4], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[204:205], v[124:125], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_9], acc[206:207], v[126:127], %[v_acc_9] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[208:209], v[96:97], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[210:211], v[98:99], %[v_acc_10] \n"
|
||||
" buffer_load_dwordx4 acc[80:83], %[v_os_b5], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[212:213], v[100:101], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[214:215], v[102:103], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[216:217], v[104:105], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[218:219], v[106:107], %[v_acc_10] \n"
|
||||
" buffer_load_dwordx4 acc[84:87], %[v_os_b5], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[220:221], v[108:109], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_10], acc[222:223], v[110:111], %[v_acc_10] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[208:209], v[112:113], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[210:211], v[114:115], %[v_acc_11] \n"
|
||||
" buffer_load_dwordx4 acc[88:91], %[v_os_b5], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[212:213], v[116:117], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[214:215], v[118:119], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[216:217], v[120:121], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[218:219], v[122:123], %[v_acc_11] \n"
|
||||
" buffer_load_dwordx4 acc[92:95], %[v_os_b5], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[220:221], v[124:125], %[v_acc_11] \n"
|
||||
_UK_MFMA_ " %[v_acc_11], acc[222:223], v[126:127], %[v_acc_11] \n"
|
||||
" s_waitcnt vmcnt(32) \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[224:225], v[96:97], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[226:227], v[98:99], %[v_acc_12] \n"
|
||||
" buffer_load_dwordx4 acc[96:99], %[v_os_b6], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[228:229], v[100:101], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[230:231], v[102:103], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[232:233], v[104:105], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[234:235], v[106:107], %[v_acc_12] \n"
|
||||
" buffer_load_dwordx4 acc[100:103], %[v_os_b6], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[236:237], v[108:109], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_12], acc[238:239], v[110:111], %[v_acc_12] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[224:225], v[112:113], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[226:227], v[114:115], %[v_acc_13] \n"
|
||||
" buffer_load_dwordx4 acc[104:107], %[v_os_b6], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[228:229], v[116:117], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[230:231], v[118:119], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[232:233], v[120:121], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[234:235], v[122:123], %[v_acc_13] \n"
|
||||
" buffer_load_dwordx4 acc[108:111], %[v_os_b6], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[236:237], v[124:125], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_13], acc[238:239], v[126:127], %[v_acc_13] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[240:241], v[96:97], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[242:243], v[98:99], %[v_acc_14] \n"
|
||||
" buffer_load_dwordx4 acc[112:115], %[v_os_b7], s[20:23], 0 offen \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[244:245], v[100:101], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[246:247], v[102:103], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[248:249], v[104:105], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[250:251], v[106:107], %[v_acc_14] \n"
|
||||
" buffer_load_dwordx4 acc[116:119], %[v_os_b7], s[20:23], 0 offen offset:1024 \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[252:253], v[108:109], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_14], acc[254:255], v[110:111], %[v_acc_14] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[240:241], v[112:113], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[242:243], v[114:115], %[v_acc_15] \n"
|
||||
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[244:245], v[116:117], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[246:247], v[118:119], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[248:249], v[120:121], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[250:251], v[122:123], %[v_acc_15] \n"
|
||||
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[252:253], v[124:125], %[v_acc_15] \n"
|
||||
_UK_MFMA_ " %[v_acc_15], acc[254:255], v[126:127], %[v_acc_15] \n"
|
||||
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
|
||||
" s_cbranch_scc0 L_end%= \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
|
||||
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
|
||||
" s_add_u32 s16, s86, s16 \n"
|
||||
" s_addc_u32 s17, 0, s17 \n"
|
||||
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
|
||||
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
|
||||
" s_add_u32 s20, s86, s20 \n"
|
||||
" s_addc_u32 s21, 0, s21 \n"
|
||||
" s_branch L_start%= \n"
|
||||
"L_end%=: \n"
|
||||
" s_nop 2 \n"
|
||||
|
||||
#undef _UK_MFMA_
|
||||
@@ -331,7 +331,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
|
||||
// prefetch K tile
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
|
||||
async_load_tile_raw(
|
||||
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -355,6 +356,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
|
||||
k_dram_window,
|
||||
number<-1>{},
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
if constexpr(i_k0 < k0_loops - 1)
|
||||
@@ -386,7 +388,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
auto v_buf = load_tile(v_dram_window, bool_constant<false>{});
|
||||
auto v_buf = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
{ // tail
|
||||
gemm_0(s_acc,
|
||||
@@ -514,7 +516,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
move_tile_window(
|
||||
v_dram_window,
|
||||
{0, kK1}); // will have scratch if move this right after load_tile(v_dram)...
|
||||
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
@@ -618,7 +621,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
|
||||
{
|
||||
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
|
||||
v_buf = load_tile(
|
||||
v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
|
||||
}
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
@@ -665,8 +669,11 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
__builtin_amdgcn_s_barrier();
|
||||
async_load_tile_raw(
|
||||
k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
|
||||
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
|
||||
k_dram_window,
|
||||
number<-1>{},
|
||||
k_oob_ck,
|
||||
k_pre_np);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
}
|
||||
// tail
|
||||
|
||||
@@ -3,7 +3,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
|
||||
|
||||
421
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
Normal file
421
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp
Normal file
@@ -0,0 +1,421 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
// clang-format off
|
||||
// [indexing implementation-1]
|
||||
// using M_a as constexpr block_size to partition all tokens into different slices
|
||||
// each slice map to one expert, and one expert can have multiple slices
|
||||
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
|
||||
// * this could be larger than actual, since actual tokens are on GPU
|
||||
//
|
||||
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
|
||||
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
|
||||
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
|
||||
//
|
||||
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
|
||||
//
|
||||
// * Note on token_id_per_expert/sorted_token_ids_ptr data:
|
||||
// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
|
||||
// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
|
||||
// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
|
||||
//
|
||||
// 32bit 0........23 24.....31 bit
|
||||
// (data) -> (token_id | topk_id)
|
||||
// low 24 bit is for token id, top 8 bit is for topk id
|
||||
//
|
||||
// the input after smooth-quant is [token, topk, hidden_dim], originally it is [token, hidden_dim]
|
||||
// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
|
||||
//
|
||||
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
|
||||
// * length is (max_num_tokens_padded + block_size - 1) / block_size
|
||||
//
|
||||
// num_tokens_post_padded_ptr : [28]
|
||||
// num_sorted_tiles_ptr : [7]
|
||||
//
|
||||
// * different from vLLM
|
||||
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
|
||||
// 2)need sorted_weight_ptr
|
||||
// 3) use num_sorted_tiles_ptr, already divided by M_a
|
||||
//
|
||||
// * below used for indexing
|
||||
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
|
||||
// 2) sorted_weight_ptr
|
||||
// 3) sorted_expert_ids_ptr
|
||||
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
|
||||
//
|
||||
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
|
||||
//
|
||||
// [indexing implementation-2]
|
||||
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
|
||||
// tok-0 tok-1 tok-2 tok-3 tok-4
|
||||
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
|
||||
//
|
||||
// we generate original rol/col id as
|
||||
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
|
||||
// let x be one element of above, we can get:
|
||||
// tpok_row_id(token_id) = x % num_tokens(5)
|
||||
// tpok_col_id(expert_Id) = x / num_tokens
|
||||
// topk_row_id/col_id can be used to access original topk_ids/topk_weight
|
||||
//
|
||||
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 5, 5]]
|
||||
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
|
||||
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
|
||||
//
|
||||
// we can get permuted_rc_ids:
|
||||
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
|
||||
//
|
||||
//
|
||||
// clang-format on
|
||||
//
|
||||
namespace ck_tile {
|
||||
|
||||
// m: num_tokens (or token*input-batch)
|
||||
// k: intermediate_size
|
||||
// n: intermediate_size used between 2 FC (TP slice this)
|
||||
// e: num expert
|
||||
// if doing pre-shuffle
|
||||
// nr : n / Block_Nr
|
||||
// kr : k / Block_Kr
|
||||
// w : fattened 1d wave buffer
|
||||
struct FusedMoeGemmHostArgs
|
||||
{
|
||||
const void* a_ptr; // [m, k], input token
|
||||
const void* a_scale_ptr; // [m, 1], token scale
|
||||
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
|
||||
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
|
||||
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
|
||||
const void* d_scale_ptr; // [e, 1, k], down scale
|
||||
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
||||
void* o_ptr; // [m, k], output token
|
||||
|
||||
const void* sorted_token_ids_ptr; // [max_num_tokens_padded]
|
||||
const void* sorted_weight_ptr; // [max_num_tokens_padded]
|
||||
const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
|
||||
const void* num_sorted_tiles_ptr; // [1]
|
||||
|
||||
index_t hidden_size; // k
|
||||
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
|
||||
index_t num_tokens; // input number of tokens for current iteration
|
||||
index_t num_experts; // number of groups
|
||||
index_t topk; // need this?
|
||||
|
||||
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
};
|
||||
|
||||
// This is scatter/gather b2b group-gemm: a fused-MoE kernel that gathers rows of the
// input token matrix (A) via sorted token ids, runs two back-to-back gemms per expert
// tile through Pipeline, and scatters results into the output with atomic-add
// (a token routed to multiple experts accumulates topk contributions).
template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmKernel
{
    // Partitioner maps blockIdx -> (sorted_tile_id, intermediate_tile_id).
    using Partitioner = remove_cvref_t<Partitioner_>;
    // Pipeline performs the actual per-tile compute (both gemms).
    using Pipeline = remove_cvref_t<Pipeline_>;
    using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used
    // static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
    // static_assert(kBlockPerCu > 0);

    using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
    static constexpr index_t BlockSize_ = BlockShape::BlockSize;

    // Element types are all taken from the pipeline's Problem description.
    using ADataType = typename Pipeline::Problem::ADataType;
    using GDataType = typename Pipeline::Problem::GDataType;
    using DDataType = typename Pipeline::Problem::DDataType;
    using AccDataType = typename Pipeline::Problem::AccDataType;
    using ODataType = typename Pipeline::Problem::ODataType;
    using AScaleDataType = typename Pipeline::Problem::AScaleDataType;
    using GScaleDataType = typename Pipeline::Problem::GScaleDataType;
    using DScaleDataType = typename Pipeline::Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
    using TopkWeightDataType = typename Pipeline::Problem::TopkWeightDataType;
    using IndexDataType = typename Pipeline::Problem::IndexDataType;
    using YDataType = typename Pipeline::Problem::YDataType;

    using Traits = typename Pipeline::Problem::Traits;
    // When true, operator() delegates the whole tile to Pipeline's micro-kernel path;
    // the else-branch below builds the tile windows explicitly instead.
    static constexpr bool UseUK = true;

    static constexpr bool IsGateOnly = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;

    // clang-format off
    // type -> short precision string used in GetName()
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<bf8_t> { static constexpr const char * name = "bf8"; };
    template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; };
    // clang-format on

    // Human-readable kernel identifier: precision(s), block/warp tile sizes, pipeline name.
    CK_TILE_HOST static std::string GetName()
    {
#define _SS_ std::string
#define _TS_ std::to_string
        // clang-format off
        using S_ = BlockShape;

        auto prec_str = [&] () {
            std::string base_str = _SS_(t2s<ADataType>::name);
            if (!std::is_same_v<ADataType, GDataType>) {
                base_str += _SS_("_") + _SS_(t2s<GDataType>::name);
            }
            return base_str;
        }();

        return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
            _TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
            _TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
            _TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
#undef _SS_
#undef _TS_
        // clang-format on
    }

    // Device-side argument pack. All pointers are raw global-memory addresses;
    // shapes are documented per field ([m]=tokens, [k]=hidden, [n]=intermediate, [e]=experts).
    struct FusedMoeGemmKargs
    {
        const void* a_ptr;              // [m, k], input token
        const void* a_scale_ptr;        // [m, 1], token scale
        const void* g_ptr;              // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
        const void* d_ptr;              // [e, n, k], pre-shuffle([e, nr, kr, w])
        const void* g_scale_ptr;        // [e, 1, n], gate(up) scale
        const void* d_scale_ptr;        // [e, 1, k], down scale
        const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
        void* o_ptr;                    // [m, k], output token

        // routing metadata (presumably produced by a moe-sorting pre-kernel — see caller)
        const void* sorted_token_ids_ptr;
        const void* sorted_weight_ptr;
        const void* sorted_expert_ids_ptr;
        const void* num_sorted_tiles_ptr;

        index_t hidden_size;       // k
        index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
        index_t num_tokens;        // input number of tokens for current iteration
        index_t num_experts;       // number of groups
        index_t topk;              // need this?

        index_t stride_token; // for input/output, stride for each row, should >= hidden_size
    };

    // TODO: switch karg based on
    using Kargs = FusedMoeGemmKargs;
    using Hargs = FusedMoeGemmHostArgs;

    CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
    {
        // TODO: hargs/kargs not guaranteed to have the same layout; bit_cast
        // assumes they are layout-identical
        return bit_cast<Kargs>(hargs);
    }

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
    {
        constexpr index_t block_m = BlockShape::Block_M0;
        // worst-case row count after each expert's token list is padded up to block_m
        int max_num_tokens_padded =
            hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
        // printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
        return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }

    // LDS byte count is fully determined by the pipeline.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        if constexpr(UseUK)
        {
            __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
            // readfirstlane: value is wave-uniform, keep it in a scalar register
            IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
                *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));

            // convert padded-row count to tile count
            num_sorted_tiles = num_sorted_tiles / BlockShape::Block_M0;

            const auto [sorted_tile_id, intermediate_tile_id] =
                Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
            // if(threadIdx.x == 0)
            // printf("bid:%d,%d, num_sorted_tiles:%d, sorted_tile_id:%d(%d),
            // intermediate_tile_id:%d\n", static_cast<int>(blockIdx.x),
            // static_cast<int>(blockIdx.y), num_sorted_tiles, sorted_tile_id, sorted_tile_id >=
            // num_sorted_tiles? 1 : 0, intermediate_tile_id);

            // grid is sized for the worst case; surplus blocks exit early
            if(sorted_tile_id >= num_sorted_tiles)
                return;

            Pipeline{}(kargs, smem, sorted_tile_id, intermediate_tile_id);
        }
        else
        {
            // allocate LDS
            // __shared__ char smem_ptr[GetSmemSize()];
            // NOTE(review): unlike the UK path, num_sorted_tiles is NOT divided by
            // Block_M0 here — confirm whether this branch expects a tile count or a
            // row count from *num_sorted_tiles_ptr.
            IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
                *reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
            // gate-only weight is [n,k]; gate+up is [2n,k]
            constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;

            // r-suffixed sizes are in units of pre-shuffled wave tiles (see FusedMoeGemmShape)
            index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
            index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
            index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
            index_t kr_1 =
                kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0

            // per-expert element strides into the g/d weight tensors
            index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
            index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;

            __shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];

            // note this is in unit of tile, need multiple tile size to get the index
            const auto [sorted_tile_id, intermediate_tile_id] =
                Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
            if(sorted_tile_id >= num_sorted_tiles)
                return;

            // every row in this tile belongs to the same expert
            const IndexDataType expert_id =
                __builtin_amdgcn_readfirstlane(reinterpret_cast<const IndexDataType*>(
                    kargs.sorted_expert_ids_ptr)[sorted_tile_id]);

            // index along intermediate_size
            // index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
            // BlockShape::Block_N0);
            index_t interm_idx_nr =
                __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);

            const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
            const auto sorted_token_id =
                a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;

            // per-thread gather index and routing weight for this row
            index_t token_id =
                reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
            auto topk_weight = reinterpret_cast<const TopkWeightDataType*>(
                kargs.sorted_weight_ptr)[sorted_token_id];

            const auto a_window = [&]() {
                // A is already pre-padded in previous kernel
                const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
                const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(kargs.num_tokens, kargs.hidden_size),
                    make_tuple(kargs.stride_token, 1),
                    number<Pipeline::kAlignmentA>{},
                    number<1>{});

                // gather is here use indexing transform
                const auto a_gather_view_ = transform_tensor_view(
                    a_view_,
                    make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
                               make_pass_through_transform(kargs.hidden_size)),
                    make_tuple(sequence<0>{}, sequence<1>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));

                const auto a_window_ = make_tile_window(
                    a_gather_view_,
                    make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
                    {0, 0});
                return a_window_;
            }();

            // TODO: gtile using NSub to have less register pressure
            const auto g_window = [&]() {
                // offset to this expert's weights, then to this block's nr slice
                const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                         static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                         interm_idx_nr * kr_0 * BlockShape::Block_W0;
                const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                    g_ptr,
                    make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
                    make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                    number<Pipeline::kAlignmentG>{},
                    number<1>{});
                const auto g_view_1_ =
                    pad_tensor_view(g_view_,
                                    make_tuple(number<BlockShape::Block_Nr0>{},
                                               number<BlockShape::Block_Kr0>{},
                                               number<BlockShape::Block_W0>{}),
                                    sequence<PadIntermediateSize, PadHiddenSize, 0>{});

                const auto g_window_ = make_tile_window(g_view_1_,
                                                        make_tuple(number<BlockShape::Block_Nr0>{},
                                                                   number<BlockShape::Block_Kr0>{},
                                                                   number<BlockShape::Block_W0>{}),
                                                        {0, 0, 0});
                return g_window_;
            }();

            const auto d_window = [&]() {
                const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                         static_cast<long_index_t>(expert_id) * expert_stride_1 +
                                         interm_idx_nr * BlockShape::Block_W1;
                // note interm_idx_nr is along the gemm-k dim of 2nd gemm

                const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                    d_ptr,
                    make_tuple(nr_1, kr_1, BlockShape::Block_W1),
                    make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
                    number<Pipeline::kAlignmentD>{},
                    number<1>{});
                const auto d_view_1_ =
                    pad_tensor_view(d_view_,
                                    make_tuple(number<BlockShape::Block_Nr1>{},
                                               number<BlockShape::Block_Kr1>{},
                                               number<BlockShape::Block_W1>{}),
                                    sequence<PadHiddenSize, PadIntermediateSize, 0>{});

                const auto d_window_ = make_tile_window(d_view_1_,
                                                        make_tuple(number<BlockShape::Block_Nr1>{},
                                                                   number<BlockShape::Block_Kr1>{},
                                                                   number<BlockShape::Block_W1>{}),
                                                        {0, 0, 0});
                return d_window_;
            }();

            auto o_window = [&]() {
                ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
                // atomic_add: multiple experts may write the same output row
                auto o_view_ = make_naive_tensor_view<address_space_enum::global,
                                                      memory_operation_enum::atomic_add>(
                    o_ptr,
                    make_tuple(kargs.num_tokens, kargs.hidden_size),
                    make_tuple(kargs.stride_token, 1),
                    number<Pipeline::kAlignmentO>{},
                    number<1>{});

                // scatter is here (same indexing transform as the gather on A)
                auto o_scatter_view_ = transform_tensor_view(
                    o_view_,
                    make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
                               make_pass_through_transform(kargs.hidden_size)),
                    make_tuple(sequence<0>{}, sequence<1>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}));

                auto o_window_ = make_tile_window(
                    o_scatter_view_,
                    make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
                    {0, 0});
                return o_window_;
            }();

            // do compute yeah
            Pipeline{}(a_window,
                       g_window,
                       d_window,
                       o_window,
                       topk_weight,
                       smem,
                       kargs.hidden_size,
                       kargs.intermediate_size,
                       kargs.stride_token);
        }
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
125
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
Normal file
125
include/ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp
Normal file
@@ -0,0 +1,125 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
tensors:
|
||||
1. act (A): input feature map
|
||||
2. gate (G): B matrix for first gemm, output will do activation(Silu)
|
||||
3. up (U): B matrix for first gemm
|
||||
4. down (D): B matrix for second gemm
|
||||
N1
|
||||
/ \
|
||||
+----------+ |
|
||||
| Down | |
|
||||
x----------x |
|
||||
hidden hidden K1 | | |
|
||||
N0 N0 x----------x |
|
||||
| +------x-----x------+------x-----x------+ | | |
|
||||
dim | | Gate | | | Up | | | | | |
|
||||
contiguous | | | | | | | | | | |
|
||||
| | | | | | | | | | |
|
||||
v +------x-----x------+------x-----x------+ +----------+ V
|
||||
K0 | | | | | contiguous
|
||||
/ \ v v v v |
|
||||
+---------+ +------x-----x------+------x-----x------+ |
|
||||
M0 | A | | | | | | | | |
|
||||
+---------+ +------x-----x------+------x-----x------+ |
|
||||
----------> | | |
|
||||
contiguous | V V
|
||||
| x-----x +----------+
|
||||
+------------> M1 | Y | ---------> | Out(O) |
|
||||
ACT x-----x +----------+
|
||||
K1 = N0 dim
|
||||
|
||||
* Note: Act could be Gelu/Silu/...
|
||||
* Note: some model does not have Up
|
||||
*/
|
||||
// Compile-time tile-shape descriptor for the fused-MoE b2b gemm.
// *_0 parameters describe gemm-0 (act x gate/up), *_1 describe gemm-1 (y x down).
// Each parameter is a sequence of [M, N, K] extents.
template <typename BlockTile_0_,
          typename WarpPerBlock_0_,
          typename WarpTile_0_,
          typename BlockTile_1_,
          typename WarpPerBlock_1_,
          typename WarpTile_1_>
struct FusedMoeGemmShape
{
    using BlockTile_0    = remove_cvref_t<BlockTile_0_>;
    using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_0_>;
    using WarpTile_0     = remove_cvref_t<WarpTile_0_>;
    using BlockTile_1    = remove_cvref_t<BlockTile_1_>;
    using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_1_>;
    using WarpTile_1     = remove_cvref_t<WarpTile_1_>;

    // total warps per block = product of the gemm-0 warp grid
    static constexpr index_t NumWarps =
        reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});

    // TODO: we don't support rounding half a warp up to 1 warp here;
    // both gemms must use exactly the same number of warps
    static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));

    // ---- gemm-0 extents ----
    static constexpr index_t Block_M0         = BlockTile_0::at(number<0>{});
    static constexpr index_t Block_N0         = BlockTile_0::at(number<1>{});
    static constexpr index_t Block_K0         = BlockTile_0::at(number<2>{});
    static constexpr index_t WarpPerBlock_M0  = WarpPerBlock_0::at(number<0>{});
    static constexpr index_t WarpPerBlock_N0  = WarpPerBlock_0::at(number<1>{});
    static constexpr index_t WarpPerBlock_K0  = WarpPerBlock_0::at(number<2>{});
    static constexpr index_t Warp_M0          = WarpTile_0::at(number<0>{});
    static constexpr index_t Warp_N0          = WarpTile_0::at(number<1>{});
    static constexpr index_t Warp_K0          = WarpTile_0::at(number<2>{});

    // elements covered by one pass of all warps; block tile must be a whole
    // number of such passes (the Repeat_* factors)
    static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0;
    static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0;
    static constexpr index_t ThreadPerBlock_K0 = Warp_K0 * WarpPerBlock_K0;
    static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
    static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
    static_assert(Block_K0 % ThreadPerBlock_K0 == 0);
    static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0;
    static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0;
    static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0;

    // ---- gemm-1 extents (same derivations as gemm-0) ----
    static constexpr index_t Block_M1         = BlockTile_1::at(number<0>{});
    static constexpr index_t Block_N1         = BlockTile_1::at(number<1>{});
    static constexpr index_t Block_K1         = BlockTile_1::at(number<2>{});
    static constexpr index_t WarpPerBlock_M1  = WarpPerBlock_1::at(number<0>{});
    static constexpr index_t WarpPerBlock_N1  = WarpPerBlock_1::at(number<1>{});
    static constexpr index_t WarpPerBlock_K1  = WarpPerBlock_1::at(number<2>{});
    static constexpr index_t Warp_M1          = WarpTile_1::at(number<0>{});
    static constexpr index_t Warp_N1          = WarpTile_1::at(number<1>{});
    static constexpr index_t Warp_K1          = WarpTile_1::at(number<2>{});

    static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1;
    static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1;
    static constexpr index_t ThreadPerBlock_K1 = Warp_K1 * WarpPerBlock_K1;
    static_assert(Block_M1 % ThreadPerBlock_M1 == 0);
    static_assert(Block_N1 % ThreadPerBlock_N1 == 0);
    static_assert(Block_K1 % ThreadPerBlock_K1 == 0);
    static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1;
    static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
    static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;

    static constexpr index_t BlockSize = warpSize * NumWarps;

    // some assert
    static_assert(Block_M0 == Block_M1);
    static_assert(Block_N0 == Block_K1 || (Block_N0 / 2) == Block_K1); // Gate Only or Gate+Up

    // pre-shuffle tile size compute (assume only for B matrix)
    // we flatten the each wave tile to a 1d linear tensor(at model loading time)
    // e.g. originally we have Block_N*Block_K tile size, after pre-shuffle
    // we can have Block_Nr*Block_Kr*Block_W, where Block_W is Warp_N*Warp_K,
    // and Block_Nr=Block_N/Warp_N, Block_Kr=Block_K/Warp_K
    static constexpr index_t Block_W0  = Warp_N0 * Warp_K0;
    static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0;
    static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
    static constexpr index_t Block_W1  = Warp_N1 * Warp_K1;
    static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1;
    static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;

    static_assert(Block_W0 == Block_W1);
    // static_assert(Block_Nr0 == Block_Kr1);
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockShape_>
|
||||
struct FusedMoeGemmTilePartitioner_Linear
|
||||
{
|
||||
// FusedMoeGemmShape
|
||||
using BlockShape = ck_tile::remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr const char* name = "lin";
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
|
||||
ck_tile::index_t /*intermediate_size*/)
|
||||
{
|
||||
index_t i_n = blockIdx.x;
|
||||
index_t i_m = blockIdx.y;
|
||||
|
||||
return ck_tile::make_tuple(i_m, i_n);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t intermediate_size)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0);
|
||||
index_t ns = ck_tile::integer_divide_ceil(intermediate_size, BlockShape::Block_N0);
|
||||
return dim3(ns, ms, 1);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,651 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
|
||||
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
|
||||
|
||||
<----- gemm-N ------>
|
||||
+----+----+----+----+
|
||||
| w0 | w1 | w2 | w3 | gemm-m
|
||||
+----+----+----+----+
|
||||
*/
|
||||
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
|
||||
struct FusedMoeGemmPipeline_FlatmmEx
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
|
||||
|
||||
using ADataType = typename Problem::ADataType;
|
||||
using GDataType = typename Problem::GDataType;
|
||||
using DDataType = typename Problem::DDataType;
|
||||
using AccDataType = typename Problem::AccDataType;
|
||||
using ODataType = typename Problem::ODataType;
|
||||
using AScaleDataType = typename Problem::AScaleDataType;
|
||||
using GScaleDataType = typename Problem::GScaleDataType;
|
||||
using DScaleDataType = typename Problem::DScaleDataType;
|
||||
using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
|
||||
using TopkWeightDataType = typename Problem::TopkWeightDataType;
|
||||
using IndexDataType = typename Problem::IndexDataType;
|
||||
using YDataType = typename Problem::YDataType;
|
||||
|
||||
using Traits = typename Problem::Traits;
|
||||
|
||||
static constexpr bool IsGateOnly = Traits::IsGateOnly;
|
||||
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
|
||||
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
|
||||
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
|
||||
|
||||
static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
|
||||
static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
|
||||
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
|
||||
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
|
||||
|
||||
static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
|
||||
static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
|
||||
static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
|
||||
static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
// minimize occupancy
|
||||
return 2;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "fused_moe_flatmm";
|
||||
|
||||
// TODO: there are multiple buffers
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
|
||||
{
|
||||
return Policy::template GetSmemSize_A<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
// this is the thread-offset along row/col
|
||||
CK_TILE_HOST_DEVICE static auto GetACoord()
|
||||
{
|
||||
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
|
||||
const auto a_coord = a_dist.calculate_index();
|
||||
return a_coord;
|
||||
}
|
||||
|
||||
// this is the thread-offset along row/col
|
||||
CK_TILE_HOST_DEVICE static auto GetOCoord()
|
||||
{
|
||||
constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
|
||||
const auto o_coord = o_dist.calculate_index();
|
||||
return o_coord;
|
||||
}
|
||||
|
||||
template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
|
||||
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
|
||||
const GWindow& g_window_,
|
||||
const DWindow& d_window_,
|
||||
OWindow& o_window_,
|
||||
TopkWeightDataType /*topk_weight*/,
|
||||
CK_TILE_LDS_ADDR void* smem,
|
||||
index_t hidden_size,
|
||||
index_t intermediate_size)
|
||||
{
|
||||
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wc++20-extensions\"");
|
||||
constexpr auto NEG1 = number<-1>{};
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto TRUE = bool_constant<true>{};
|
||||
constexpr auto FALSE = bool_constant<false>{};
|
||||
|
||||
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
|
||||
CK_TILE_LDS_ADDR ADataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(
|
||||
reinterpret_cast<CK_TILE_LDS_ADDR char*>(smem) +
|
||||
Policy::template GetSmemSize_A<Problem>());
|
||||
|
||||
auto g_view = g_window_.get_bottom_tensor_view();
|
||||
|
||||
auto u_view = [&]() {
|
||||
if constexpr(IsGateOnly)
|
||||
{
|
||||
return g_view;
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
|
||||
index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
|
||||
|
||||
const GDataType* g_ptr =
|
||||
g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
|
||||
const GDataType* u_ptr = g_ptr + (nr_0 / 2) * kr_0 * number<BlockShape::Block_W0>{};
|
||||
|
||||
const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
|
||||
u_ptr,
|
||||
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
|
||||
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
|
||||
number<kAlignmentG>{},
|
||||
number<1>{});
|
||||
const auto u_view_1_ =
|
||||
pad_tensor_view(u_view_,
|
||||
make_tuple(number<BlockShape::Block_Nr0>{},
|
||||
number<BlockShape::Block_Kr0>{},
|
||||
number<BlockShape::Block_W0>{}),
|
||||
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
|
||||
return u_view_1_;
|
||||
}
|
||||
}();
|
||||
|
||||
auto a_win = make_tile_window_linear(
|
||||
a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
|
||||
auto g_win =
|
||||
make_tile_window_linear(g_window_,
|
||||
Policy::template MakeGlobalTileDistribution_G<Problem>(),
|
||||
sequence<0, 1, 1>{});
|
||||
auto d_win =
|
||||
make_tile_window_linear(d_window_,
|
||||
Policy::template MakeGlobalTileDistribution_D<Problem>(),
|
||||
sequence<0, 1, 1>{});
|
||||
auto o_win = make_tile_window_linear(
|
||||
o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
|
||||
|
||||
using g_thread_type = decltype(load_tile(g_win));
|
||||
using d_thread_type = decltype(load_tile(d_win));
|
||||
|
||||
using WarpGemm0 = decltype(Policy::template GetWarpGemm0<Problem>());
|
||||
using WarpGemm1 = decltype(Policy::template GetWarpGemm1<Problem>());
|
||||
auto warp_gemm_0 = WarpGemm0{};
|
||||
auto warp_gemm_1 = WarpGemm1{};
|
||||
|
||||
// issues_warps_lanes
|
||||
auto a_sst_win0 =
|
||||
make_tile_window(make_tensor_view<address_space_enum::lds>(
|
||||
smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
|
||||
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
|
||||
{0, 0, 0});
|
||||
|
||||
auto a_sst_win1 =
|
||||
make_tile_window(make_tensor_view<address_space_enum::lds>(
|
||||
smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
|
||||
Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
|
||||
{0, 0, 0});
|
||||
// m*k
|
||||
auto a_sld_win0 = [&]() {
|
||||
using WG = WarpGemm0;
|
||||
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
|
||||
sequence<BlockShape::Repeat_K0>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
|
||||
return make_tile_window_linear(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
|
||||
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(a_block_dstr_encode));
|
||||
}();
|
||||
|
||||
// m*k
|
||||
auto a_sld_win1 = [&]() {
|
||||
using WG = WarpGemm0;
|
||||
constexpr auto a_outer_dstr_enc = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<BlockShape::Repeat_M0, BlockShape::WarpPerBlock_M0>,
|
||||
sequence<BlockShape::Repeat_K0>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_outer_dstr_enc, typename WG::AWarpDstrEncoding{});
|
||||
return make_tile_window_linear(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
|
||||
Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(a_block_dstr_encode));
|
||||
}();
|
||||
|
||||
auto bridge_sst_win = [&]() {
|
||||
return make_tile_window(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<YDataType*>(smem),
|
||||
Policy::template MakeBridgeLdsStoreDesc<Problem>()),
|
||||
Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
|
||||
{0, 0});
|
||||
}();
|
||||
|
||||
auto bridge_sld_win = [&]() {
|
||||
return make_tile_window_linear(
|
||||
make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<YDataType*>(smem),
|
||||
Policy::template MakeBridgeLdsLoadDesc<Problem>()),
|
||||
Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeYTileDistribution<Problem>());
|
||||
}();
|
||||
|
||||
// also OK with C array, 2 register buffer
|
||||
statically_indexed_array<g_thread_type, 2> gs;
|
||||
|
||||
constexpr auto issues_a = number<a_win.get_num_of_access()>{};
|
||||
constexpr auto issues_g = number<g_win.get_num_of_access()>{};
|
||||
// constexpr auto issues_d = number<d_win.get_num_of_access()>{};
|
||||
// constexpr auto issues_o = number<o_win.get_num_of_access()>{};
|
||||
constexpr auto issues_gemm0 =
|
||||
number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
|
||||
warp_gemm_0.get_num_of_access()>{};
|
||||
constexpr auto issues_gemm1 =
|
||||
number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
|
||||
warp_gemm_1.get_num_of_access()>{};
|
||||
// constexpr auto issues_sld_a = number<a_sld_win0.get_num_of_access()>{};
|
||||
|
||||
const index_t num_blocks_k0 =
|
||||
(hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
|
||||
const index_t num_blocks_n1 =
|
||||
(hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
|
||||
|
||||
using a_thread_type = decltype(load_tile(a_sld_win0));
|
||||
statically_indexed_array<a_thread_type, 2> as;
|
||||
|
||||
auto gld_a = [&]<typename PreNop = bool_constant<false>>(
|
||||
auto& a_store_, auto i_access, PreNop = {})
|
||||
{
|
||||
async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
|
||||
};
|
||||
auto move_a = [&]() {
|
||||
move_tile_window(a_win, {number<0>{}, number<BlockShape::Block_K0>{}});
|
||||
};
|
||||
auto sld_a = [&](auto& a_, auto& win_, auto i_access) {
|
||||
load_tile_raw(a_, win_, i_access);
|
||||
};
|
||||
|
||||
auto gld_g = [&]<typename PreNop = bool_constant<false>>(
|
||||
auto& g_, auto i_access, PreNop = {})
|
||||
{
|
||||
if constexpr(IsGateOnly)
|
||||
{
|
||||
// TODO: hack!
|
||||
if constexpr(i_access.value == 0)
|
||||
{
|
||||
g_win.bottom_tensor_view_ = g_view;
|
||||
}
|
||||
else if constexpr(i_access.value == issues_g / 2)
|
||||
{
|
||||
g_win.bottom_tensor_view_ = u_view;
|
||||
}
|
||||
}
|
||||
load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
|
||||
};
|
||||
auto move_g = [&]() {
|
||||
move_tile_window(g_win, {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
|
||||
};
|
||||
statically_indexed_array<d_thread_type, 2> ds;
|
||||
|
||||
auto gld_d = [&]<typename PreNop = bool_constant<false>>(
|
||||
auto& d_, auto i_access, PreNop = {})
|
||||
{
|
||||
load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
|
||||
};
|
||||
auto move_d = [&]() {
|
||||
// d move along gemm-n
|
||||
move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
|
||||
};
|
||||
|
||||
auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
|
||||
auto& o_, auto i_access, PreNop = {})
|
||||
{
|
||||
update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
|
||||
};
|
||||
|
||||
auto acc_0 = Policy::template MakeCBlockTile_Gemm0<Problem>();
|
||||
auto acc_1s = generate_tuple(
|
||||
[&](auto) { return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, number<2>{});
|
||||
|
||||
// clang-format off
|
||||
auto gemm_0 = [&]<typename PostNop = bool_constant<false>>
|
||||
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
|
||||
using WarpGemm = remove_cvref_t<decltype(warp_gemm_0)>;
|
||||
|
||||
constexpr auto repeat_sub = WarpGemm::get_num_of_access();
|
||||
constexpr auto repeat_m = BlockShape::Repeat_M0;
|
||||
// constexpr auto repeat_n = BlockShape::Repeat_N0;
|
||||
constexpr auto repeat_k = BlockShape::Repeat_K0;
|
||||
// loop order n->m->k
|
||||
constexpr auto i_sub = i_access % repeat_sub;
|
||||
constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
|
||||
constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
|
||||
constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
AWarpTensor w_a;
|
||||
w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
BWarpTensor w_b;
|
||||
w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
CWarpTensor w_c;
|
||||
w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
warp_gemm_0(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
|
||||
|
||||
t_c.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
w_c.get_thread_buffer());
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
auto gemm_1 = [&]<typename PostNop = bool_constant<false>>
|
||||
(auto& t_c, auto& t_a, auto& t_b, auto i_access, PostNop = {}) {
|
||||
using WarpGemm = remove_cvref_t<decltype(warp_gemm_1)>;
|
||||
|
||||
constexpr auto repeat_sub = WarpGemm::get_num_of_access();
|
||||
constexpr auto repeat_m = BlockShape::Repeat_M0;
|
||||
// constexpr auto repeat_n = BlockShape::Repeat_N0;
|
||||
constexpr auto repeat_k = BlockShape::Repeat_K0;
|
||||
// loop order n->m->k
|
||||
constexpr auto i_sub = i_access % repeat_sub;
|
||||
constexpr auto i_k = (i_access / repeat_sub) % repeat_k;
|
||||
constexpr auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
|
||||
constexpr auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
constexpr auto a_warp_y_lengths = to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths = to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths = to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
AWarpTensor w_a;
|
||||
w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<i_m, i_k>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
BWarpTensor w_b;
|
||||
w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<i_n, i_k>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
CWarpTensor w_c;
|
||||
w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
warp_gemm_1(w_c, w_a, w_b, number<i_sub>{}, PostNop{});
|
||||
|
||||
t_c.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<i_m, i_n>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
w_c.get_thread_buffer());
|
||||
};
|
||||
// clang-format on
|
||||
_Pragma("clang diagnostic pop");
|
||||
|
||||
// this gemm pipeline is designed with assumption that issues of buffer-load/ds_read can
|
||||
// be hide under mfma. In other words, issues of mfma is >= memory this is true if we
|
||||
// pre-shuffle B matrix, and A matrix is relatively small we prefer use multiple mfma
|
||||
// paired with 1 buffer-load B matrix, to get max throughput of buffer_load. and by
|
||||
// preshuffle, we always pack to dwordx4 load, and this will already extend to multiple
|
||||
// mfma but that is already consumed inside warpgemm-impl. So indeed how many extra
|
||||
// mfma(that can reuse the B matrix) only affected by M repeat.
|
||||
auto pipeline_gemm0 = [&]() {
|
||||
constexpr index_t total_loops = issues_gemm0;
|
||||
constexpr auto sr = Policy::template GetSequencer_0<Problem>();
|
||||
static_assert(sr.size() == total_loops);
|
||||
|
||||
constexpr auto c_sld_a_0 = MAKE_SC();
|
||||
constexpr auto c_gld_a_0 = MAKE_SC();
|
||||
constexpr auto c_gld_b_0 = MAKE_SC();
|
||||
// compute buffer 1
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_0(acc_0, as[I0], gs[I0], i_issue);
|
||||
constexpr index_t slot = sr.at(i_issue);
|
||||
|
||||
if constexpr(slot & SLD_A)
|
||||
sld_a(as[I1], a_sld_win1, number<NEXT_SCI(c_sld_a_0, i_issue)>{});
|
||||
if constexpr(slot & GLD_A)
|
||||
gld_a(a_sst_win0, number<NEXT_SCI(c_gld_a_0, i_issue)>{});
|
||||
if constexpr(slot & GLD_B)
|
||||
gld_g(gs[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
|
||||
});
|
||||
move_g();
|
||||
move_a();
|
||||
block_sync_load_raw(issues_a + issues_g);
|
||||
lds_load_fence();
|
||||
|
||||
constexpr auto c_sld_a_1 = MAKE_SC();
|
||||
constexpr auto c_gld_a_1 = MAKE_SC();
|
||||
constexpr auto c_gld_b_1 = MAKE_SC();
|
||||
|
||||
// compute buffer 1
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_0(acc_0, as[I1], gs[I1], i_issue);
|
||||
constexpr index_t slot = sr.at(i_issue);
|
||||
|
||||
if constexpr(slot & SLD_A)
|
||||
sld_a(as[I0], a_sld_win0, number<NEXT_SCI(c_sld_a_1, i_issue)>{});
|
||||
if constexpr(slot & GLD_A)
|
||||
gld_a(a_sst_win1, number<NEXT_SCI(c_gld_a_1, i_issue)>{});
|
||||
if constexpr(slot & GLD_B)
|
||||
gld_g(gs[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
|
||||
});
|
||||
move_g();
|
||||
move_a();
|
||||
block_sync_load_raw(issues_a + issues_g);
|
||||
lds_load_fence();
|
||||
};
|
||||
|
||||
auto pipeline_gemm0_tail = [&]() {
|
||||
constexpr index_t total_loops = issues_gemm0;
|
||||
constexpr auto sr = Policy::template GetSequencer_0<Problem>();
|
||||
static_assert(sr.size() == total_loops);
|
||||
|
||||
constexpr auto c_gld_b_0 = MAKE_SC();
|
||||
|
||||
// compute buffer 0
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_0(acc_0, as[I0], gs[I0], i_issue);
|
||||
constexpr index_t slot = sr.at(i_issue);
|
||||
|
||||
if constexpr(slot & GLD_B)
|
||||
gld_g(gs[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
|
||||
});
|
||||
|
||||
block_sync_load_raw(issues_g);
|
||||
sld_a(as[I1], a_sld_win1, NEG1);
|
||||
|
||||
// compute buffer 1
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
constexpr auto last_nop = [&]() {
|
||||
if constexpr(i_issue == (total_loops - 1))
|
||||
return TRUE;
|
||||
else
|
||||
return FALSE;
|
||||
}();
|
||||
gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop); // last gemm has nop
|
||||
});
|
||||
};
|
||||
|
||||
auto y = Policy::template MakeYBlockTile<Problem>();
|
||||
|
||||
auto pipeline_bridge = [&]() {
|
||||
// cast to Y data
|
||||
auto y_pre = cast_tile<YDataType>(acc_0);
|
||||
store_tile(bridge_sst_win, y_pre);
|
||||
clear_tile(acc_1s(I0));
|
||||
// wave_barrier();
|
||||
load_tile(y, bridge_sld_win);
|
||||
clear_tile(acc_1s(I1));
|
||||
};
|
||||
|
||||
// note, gemm-1 start from idx-1 to N-2 (0, 1, 2....N-1)
|
||||
auto pipeline_gemm1 = [&]() {
|
||||
constexpr index_t total_loops = issues_gemm1;
|
||||
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
|
||||
static_assert(sr.size() == total_loops);
|
||||
|
||||
constexpr auto c_gld_b_0 = MAKE_SC();
|
||||
constexpr auto c_gst_o_0 = MAKE_SC();
|
||||
constexpr auto c_gld_b_1 = MAKE_SC();
|
||||
constexpr auto c_gst_o_1 = MAKE_SC();
|
||||
|
||||
// compute buffer 0
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
|
||||
constexpr index_t slot = sr.at(i_issue);
|
||||
if constexpr(slot & GLD_B)
|
||||
gld_d(ds[I0], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
|
||||
|
||||
if constexpr(slot & GST_O)
|
||||
{
|
||||
auto out = cast_tile<ODataType>(acc_1s[I0]);
|
||||
atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
|
||||
}
|
||||
});
|
||||
move_d();
|
||||
// move_o();
|
||||
|
||||
// compute buffer 1
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
|
||||
constexpr index_t slot = sr.at(i_issue);
|
||||
if constexpr(slot & GLD_B)
|
||||
gld_d(ds[I1], number<NEXT_SCI(c_gld_b_1, i_issue)>{});
|
||||
|
||||
if constexpr(slot & GST_O)
|
||||
{
|
||||
auto out = cast_tile<ODataType>(acc_1s[I1]);
|
||||
atomic_add_o(out, number<NEXT_SCI(c_gst_o_1, i_issue)>{});
|
||||
}
|
||||
});
|
||||
move_d();
|
||||
};
|
||||
|
||||
auto pipeline_gemm1_head = [&]() {
|
||||
constexpr index_t total_loops = issues_gemm1;
|
||||
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
|
||||
static_assert(sr.size() == total_loops);
|
||||
|
||||
constexpr auto c_gld_b_0 = MAKE_SC();
|
||||
|
||||
// compute buffer 0
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_1(acc_1s[I0], y, ds[I0], i_issue);
|
||||
constexpr index_t slot = sr.at(i_issue);
|
||||
if constexpr(slot & GLD_B)
|
||||
gld_d(ds[I1], number<NEXT_SCI(c_gld_b_0, i_issue)>{});
|
||||
});
|
||||
move_d();
|
||||
};
|
||||
auto pipeline_gemm1_tail = [&]() {
|
||||
constexpr index_t total_loops = issues_gemm1;
|
||||
constexpr auto sr = Policy::template GetSequencer_1<Problem>();
|
||||
static_assert(sr.size() == total_loops);
|
||||
|
||||
constexpr auto c_gst_o_0 = MAKE_SC();
|
||||
|
||||
// compute buffer 1
|
||||
static_for<0, total_loops, 1>{}([&](auto i_issue) {
|
||||
gemm_1(acc_1s[I1], y, ds[I1], i_issue);
|
||||
|
||||
constexpr index_t slot = sr.at(i_issue);
|
||||
if constexpr(slot & GST_O)
|
||||
{
|
||||
auto out = cast_tile<ODataType>(acc_1s[I0]);
|
||||
atomic_add_o(out, number<NEXT_SCI(c_gst_o_0, i_issue)>{});
|
||||
}
|
||||
});
|
||||
{
|
||||
auto out = cast_tile<ODataType>(acc_1s[I1]);
|
||||
atomic_add_o(out, NEG1);
|
||||
}
|
||||
};
|
||||
|
||||
// start of pipeline
|
||||
// clang-format off
|
||||
gld_a(a_sst_win0, NEG1, TRUE);
|
||||
gld_g(gs[I0], NEG1, TRUE);
|
||||
move_a();
|
||||
move_g();
|
||||
clear_tile(acc_0);
|
||||
|
||||
// preload for next round
|
||||
gld_a(a_sst_win1, NEG1);
|
||||
gld_g(gs[I1], NEG1);
|
||||
|
||||
// make sure a,g loaded
|
||||
block_sync_load_raw(issues_a + issues_g);
|
||||
lds_load_fence();
|
||||
|
||||
// we manually unroll double buffer inside hot loop
|
||||
const index_t iters_0 = (num_blocks_k0 - 2) / 2;
|
||||
index_t i_0 = 0; // (void)i_0; (void)iters_0; (void)pipeline_gemm0;
|
||||
while(i_0++ < iters_0)
|
||||
{
|
||||
pipeline_gemm0();
|
||||
}
|
||||
pipeline_gemm0_tail();
|
||||
|
||||
pipeline_bridge();
|
||||
|
||||
const index_t iters_1 = (num_blocks_n1 - 2) / 2;
|
||||
index_t i_1 = 0; // (void) i_1; (void)iters_1; (void)pipeline_gemm1;
|
||||
pipeline_gemm1_head();
|
||||
while(i_1++ < iters_1)
|
||||
{
|
||||
pipeline_gemm1();
|
||||
}
|
||||
pipeline_gemm1_tail();
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,831 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct FusedMoeGemmPipelineFlatmmPolicy
|
||||
{
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
|
||||
{
|
||||
// TODO: always 1 dword
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
|
||||
{
|
||||
// using async
|
||||
constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
|
||||
constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
|
||||
static_assert(copy_bytes % data_bytes == 0);
|
||||
return copy_bytes / data_bytes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
|
||||
{
|
||||
constexpr index_t copy_bytes = [&]() { return 16; }();
|
||||
constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
|
||||
static_assert(copy_bytes % data_bytes == 0);
|
||||
return copy_bytes / data_bytes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
|
||||
{
|
||||
constexpr index_t copy_bytes = [&]() { return 16; }();
|
||||
constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
|
||||
static_assert(copy_bytes % data_bytes == 0);
|
||||
return copy_bytes / data_bytes;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O()
|
||||
{
|
||||
if constexpr(Problem::Traits::OAtomic == 1)
|
||||
{
|
||||
// pack fp16/bf16 atomic
|
||||
static_assert(sizeof(typename Problem::ODataType) == 2);
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(Problem::Traits::OAtomic == 2)
|
||||
{
|
||||
// fp32 atomic
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 16 / sizeof(typename Problem::ODataType);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType_>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
|
||||
{
|
||||
// TODO: this is for 3d layout
|
||||
return 16 / sizeof(remove_cvref_t<DataType_>);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
|
||||
{
|
||||
return GetSmemKPack<typename Problem::ADataType>();
|
||||
}
|
||||
|
||||
// used for bridge LDS shuffle
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y()
|
||||
{
|
||||
// TODO: this should match mfma layout
|
||||
return 16 / sizeof(typename Problem::YDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
|
||||
{
|
||||
constexpr auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
|
||||
constexpr auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
|
||||
static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
|
||||
return a_sld_desc.get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
|
||||
{
|
||||
constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>();
|
||||
constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>();
|
||||
static_assert(bridge_sld_desc.get_element_space_size() ==
|
||||
bridge_sst_desc.get_element_space_size());
|
||||
return bridge_sld_desc.get_element_space_size();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t a_lds = GetSmemSize_A<Problem>();
|
||||
constexpr index_t bridge_lds = GetSmemSize_Bridge<Problem>();
|
||||
return max(a_lds, bridge_lds);
|
||||
}
|
||||
|
||||
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
|
||||
{
|
||||
constexpr index_t K_vec = Alignment;
|
||||
constexpr index_t K_rem = KPerBlock / K_vec;
|
||||
|
||||
if constexpr(get_warp_size() < K_rem)
|
||||
{
|
||||
static_assert(K_rem % get_warp_size() == 0);
|
||||
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
|
||||
constexpr index_t K_wav = K_rem / get_warp_size();
|
||||
static_assert(K_wav <= NumWarps, "not not support thread has repeat along K yet");
|
||||
constexpr index_t M_wav = NumWarps / K_wav;
|
||||
static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
|
||||
constexpr index_t M_rep = MPerBlock / M_wav;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
|
||||
tuple<sequence<1, 2>, sequence<2>>,
|
||||
tuple<sequence<1, 0>, sequence<1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 2>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K_lan = K_rem;
|
||||
constexpr index_t M_lan = get_warp_size() / K_lan;
|
||||
constexpr index_t M_wav = NumWarps;
|
||||
static_assert(MPerBlock % (M_lan * M_wav) == 0,
|
||||
"this tile size is too small please check");
|
||||
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
}
|
||||
|
||||
    // optimized version for async copy; NOT the same layout as the simple MxK
    // distribution above (the wave dimension ordering differs — pay attention!)
    template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
    {
        // Each lane owns Alignment contiguous elements along K.
        constexpr index_t K_vec = Alignment;
        constexpr index_t K_rem = KPerBlock / K_vec; // lanes needed to cover K

        if constexpr(get_warp_size() <= K_rem)
        {
            // K is wide: one or more whole waves span K.
            static_assert(K_rem % get_warp_size() == 0);
            constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
            constexpr index_t K_wav = K_rem / get_warp_size();
            static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet");
            constexpr index_t M_wav = NumWarps / K_wav;
            static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
            constexpr index_t M_rep = MPerBlock / M_wav;
            // NOTE: no swap, but hard to avoid LDS bank conflict
            return make_static_tile_distribution(
                tile_distribution_encoding<
                    sequence<1>,
                    tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
                    tuple<sequence<1, 2>, sequence<2>>,
                    tuple<sequence<1, 0>, sequence<1>>,
                    sequence<1, 2>,
                    sequence<0, 2>>{});
        }
        else
        {
            // K is narrow: each wave covers several M rows.
            constexpr index_t K_lan = K_rem;
            constexpr index_t M_lan = get_warp_size() / K_lan;
            constexpr index_t M_wav = NumWarps;
            static_assert(MPerBlock % (M_lan * M_wav) == 0,
                          "this tile size is too small please check");
            constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
            // NOTE: swapped for LDS load bank conflict free
            return make_static_tile_distribution(
                tile_distribution_encoding<
                    sequence<1>,
                    // Note M_wav (num waves) is the fastest dim, different from the
                    // simple 2d distribution
                    tuple<sequence<M_rep, M_lan, M_wav>, sequence<K_lan, K_vec>>,
                    tuple<sequence<1>, sequence<1, 2>>,
                    tuple<sequence<2>, sequence<1, 0>>,
                    sequence<1, 2>,
                    sequence<0, 1>>{});
        }
    }
|
||||
|
||||
    // Distribution for pre-shuffled weights in [Nr, Kr, waveflatten] layout:
    // dim1 = N repeats x warps-per-block-N, dim2 = K repeats x warps-per-block-K,
    // dim3 = lane x per-lane vector.
    template <index_t WarpPerBlock_N_,
              index_t WarpPerBlock_K_,
              index_t Repeat_N_,
              index_t Repeat_K_,
              index_t WarpSize_,
              index_t Alignment_>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_Nr_Kr_W()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<Repeat_N_, WarpPerBlock_N_>,
                                             sequence<Repeat_K_, WarpPerBlock_K_>,
                                             sequence<WarpSize_, Alignment_>>,
                                       tuple<sequence<1, 2>, sequence<3>>,
                                       tuple<sequence<1, 1>, sequence<0>>,
                                       sequence<1, 2, 3>,
                                       sequence<0, 0, 1>>{});
    }
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
|
||||
{
|
||||
constexpr index_t Block_M_ = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_K_ = Problem::BlockShape::Block_K0;
|
||||
constexpr index_t NumWarps_ = Problem::BlockShape::NumWarps;
|
||||
constexpr index_t Alignment_ = GetAlignment_A<Problem>();
|
||||
return MakeGlobalTileDistribution_SimpleMxK_Async<Block_M_,
|
||||
Block_K_,
|
||||
NumWarps_,
|
||||
Alignment_>();
|
||||
}
|
||||
|
||||
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
    {
        // Tile distribution for the pre-shuffled gate/up weight (G) over the
        // gemm-0 shape.  Only the b_nr_kr_waveflatten permutation is handled;
        // for any other permute style no value is returned, which surfaces as
        // a compile-time error if instantiated.
        constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
        // constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
        using S_ = typename Problem::BlockShape;
        if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
        {
            return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
                                                      S_::WarpPerBlock_K0,
                                                      S_::Repeat_N0, /// hidden_radio_0,
                                                      S_::Repeat_K0,
                                                      get_warp_size(),
                                                      GetAlignment_G<Problem>()>();
        }
    }
|
||||
|
||||
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
    {
        // Tile distribution for the pre-shuffled down weight (D) over the
        // gemm-1 shape.  Only b_nr_kr_waveflatten is handled; other permute
        // styles produce no return value (compile-time error if instantiated).
        constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
        using S_                   = typename Problem::BlockShape;
        if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
        {
            return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N1,
                                                      S_::WarpPerBlock_K1,
                                                      S_::Repeat_N1,
                                                      S_::Repeat_K1,
                                                      get_warp_size(),
                                                      GetAlignment_D<Problem>()>();
        }
    }
|
||||
|
||||
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
    {
        // Distribution of the gemm-1 output (O) tile: block-level repeats and
        // warps along M1/N1, with the per-warp C layout embedded inside.
        using S_       = remove_cvref_t<typename Problem::BlockShape>;
        using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
        // using CDataType = typename WarpGemm::CDataType;

        constexpr auto c_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
                                             sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};

        // embed the warp-gemm C distribution into the block-level encoding
        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        return c_block_dstr;
    }
|
||||
|
||||
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
    {
        // A async->LDS: descriptor describing how the async copy lays the A
        // tile out in LDS, expressed as an (issues, warps, lanes) view.
        constexpr index_t Block_M = Problem::BlockShape::Block_M0;
        constexpr index_t Block_K = Problem::BlockShape::Block_K0;
        // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
        constexpr index_t warpSize = ck_tile::get_warp_size();
        constexpr index_t NumWarps = Problem::BlockShape::NumWarps;

        constexpr index_t KPack   = GetSmemKPack_A<Problem>(); // LDS
        constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
        constexpr index_t KPad    = KPack;                     // pad between warps

        static_assert(Block_K % KVector == 0);
        constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
        if constexpr(LanesPerK >= warpSize)
        {
            // need multiple waves to load K
            static_assert(LanesPerK % warpSize == 0);
            constexpr index_t wavesPerK = LanesPerK / warpSize;
            if constexpr(wavesPerK > NumWarps)
            {
                // TODO: need multiple issues along K to load all data
                // (currently unsupported: no descriptor is returned here)
            }
            else
            {
                constexpr index_t wavesPerM = NumWarps / wavesPerK;
                constexpr index_t NumIssues = Block_M / wavesPerM;
                // 5d view: (m-issue, m-wave, k-wave, lane, vector), with KPad
                // elements of padding appended to each wave's row.
                constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                    make_tuple(number<NumIssues>{},  // m0
                               number<wavesPerM>{},  // m1
                               number<wavesPerK>{},  // k0
                               number<warpSize>{},   // k1
                               number<KVector>{}),   // k2
                    make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{},  // m0
                               number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
                               number<warpSize * KVector + KPad>{},             // k0
                               number<KVector>{},                               // k1
                               number<1>{}),                                    // k2
                    number<KVector>{}, // lds store vector(actually no explicit store)
                    number<1>{});

                // collapse to (issues, warps, lanes)
                constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
                    lds_block_desc_0,
                    make_tuple(
                        make_pass_through_transform(number<NumIssues>{}),
                        make_merge_transform(make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
                        make_merge_transform(make_tuple(number<warpSize>{}, number<KVector>{}))),
                    make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

                return lds_block_desc_issues_warps_lanes;
            }
        }
        else
        {
            // lanes within a wave load different M but same K
            static_assert(warpSize % LanesPerK == 0);
            constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
            constexpr index_t NumIssues  = Block_M / (LaneGroups * NumWarps);

            // 5d view: (m-issue, m-lane-group, warp, k-lane, vector)
            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NumIssues>{},   // m0
                           number<LaneGroups>{},  // m1
                           number<NumWarps>{},    // m2
                           number<LanesPerK>{},   // k0
                           number<KVector>{}),    // k1
                make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
                           number<Block_K>{},                              // m1
                           number<warpSize * KVector + KPad>{},            // m2
                           number<KVector>{},                              // k0
                           number<1>{}),                                   // k1
                number<KVector>{}, // lds store vector(actually no explicit store)
                number<1>{});

            // collapse to (issues, warps, lanes)
            constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(make_pass_through_transform(number<NumIssues>{}),
                           make_pass_through_transform(number<NumWarps>{}),
                           make_merge_transform(make_tuple(
                               number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
                make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

            return lds_block_desc_issues_warps_lanes;
        }
    }
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
|
||||
{
|
||||
// A async->LDS
|
||||
// Note that, this descriptor is only to construct the layout inside LDS
|
||||
// in real Gemm pipeline, ds_read may not follow this pattern
|
||||
// (may follow that in tile_distribution)
|
||||
// below code is almost the same as SmemStore dist, with difference:
|
||||
// 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
|
||||
// 2). return discriptor is in NxK 2d layout
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t warpSize = ck_tile::get_warp_size();
|
||||
constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
|
||||
constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
|
||||
constexpr index_t KPad = KPack; // pad between warps
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= warpSize)
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % warpSize == 0);
|
||||
constexpr index_t wavesPerK = LanesPerK / warpSize;
|
||||
if constexpr(wavesPerK >= NumWarps)
|
||||
{
|
||||
// TODO: need multiple issues along K to load all data
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t wavesPerM = NumWarps / wavesPerK;
|
||||
constexpr index_t NumIssues = Block_M / wavesPerM;
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<wavesPerM>{}, // m1
|
||||
number<wavesPerK>{}, // k0
|
||||
number<warpSize>{}, // k1
|
||||
number<KVector>{}), // k2
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<wavesPerK*(warpSize * KVector + KPad)>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // k0
|
||||
number<KVector>{}, // k1
|
||||
number<1>{}), // k2
|
||||
number<KPack>{}, // lds load vector
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<NumIssues>{}, number<wavesPerM>{})),
|
||||
make_merge_transform(make_tuple(
|
||||
number<wavesPerK>{}, number<warpSize>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_desc_m_k;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(warpSize % LanesPerK == 0);
|
||||
constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
|
||||
constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
|
||||
|
||||
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumIssues>{}, // m0
|
||||
number<LaneGroups>{}, // m1
|
||||
number<NumWarps>{}, // m2
|
||||
number<LanesPerK>{}, // k0
|
||||
number<KVector>{}), // k1
|
||||
make_tuple(number<NumWarps*(warpSize * KVector + KPad)>{}, // m0
|
||||
number<Block_K>{}, // m1
|
||||
number<warpSize * KVector + KPad>{}, // m2
|
||||
number<KVector>{}, // k0
|
||||
number<1>{}), // k1
|
||||
number<KPack>{}, // lds load vector
|
||||
number<1>{});
|
||||
|
||||
constexpr auto lds_desc_m_k = transform_tensor_descriptor(
|
||||
lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
|
||||
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
|
||||
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return lds_desc_m_k;
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the LDS descriptor used to LOAD the "bridge" tile (gemm-0 output Y,
/// staged in shared memory before it is consumed as the A operand of gemm-1).
/// Layout is a plain row-major [Block_M0, Block_N0] tile with no padding
/// (KPad == 0); KVector is the guaranteed contiguous vector width for the
/// shared load.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsLoadDesc()
{
    constexpr index_t Block_M = Problem::BlockShape::Block_M0;
    constexpr index_t Block_N = Problem::BlockShape::Block_N0;

    constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
    constexpr index_t KPad    = 0;                         // pad between warps (disabled)

    // strides: row stride = Block_N + KPad, element stride = 1 (row-major)
    constexpr auto desc =
        make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
                                     make_tuple(number<Block_N + KPad>{}, number<1>{}),
                                     number<KVector>{},
                                     number<1>{});
    return desc;
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreDesc()
|
||||
{
|
||||
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
|
||||
|
||||
constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
|
||||
constexpr index_t KPad = 0; // KVector; // pad between warps
|
||||
|
||||
constexpr auto desc =
|
||||
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
|
||||
make_tuple(number<Block_N + KPad>{}, number<1>{}),
|
||||
number<KVector>{},
|
||||
number<1>{});
|
||||
return desc;
|
||||
}
|
||||
|
||||
/// Build the LDS store descriptor for the bridge tile when the micro-kernel
/// (UK) path is used. The raw 6-D layout matches the UK's C-register layout
/// (repeats x warps x lanes), then is merged down to a 2-D [m, n] view.
/// NOTE(review): kAMLane/kABKLane/kABKPerLane (16/4/4) look tied to a
/// 16x16x32 MFMA lane mapping — confirm against the UK implementation.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsStoreForUKDesc()
{
    constexpr index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0;
    constexpr index_t Repeat_N       = Problem::BlockShape::Repeat_N0;
    constexpr index_t Repeat_M       = Problem::BlockShape::Repeat_M0;

    // per-lane tiling constants of the warp-level gemm output
    constexpr index_t kAMLane     = 16;
    constexpr index_t kABKLane    = 4;
    constexpr index_t kABKPerLane = 4;

    constexpr index_t KPack = kABKPerLane;

    // raw layout, fastest-varying last; dims tagged m/n by which merged
    // output dimension they belong to
    constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
        make_tuple(number<Repeat_M>{},       // m
                   number<Repeat_N>{},       // n
                   number<WarpPerBlock_N>{}, // n
                   number<kABKLane>{},       // n
                   number<kAMLane>{},        // m
                   number<KPack>{}),         // n
        make_tuple(number<Repeat_N * WarpPerBlock_N * kABKLane * kAMLane * KPack>{}, // m
                   number<WarpPerBlock_N * kABKLane * kAMLane * KPack>{},            // n
                   number<kABKLane * kAMLane * KPack>{},                             // n
                   number<kAMLane * KPack>{},                                        // n
                   number<KPack>{},                                                  // m
                   number<1>{}),                                                     // n
        number<KPack>{}, // lds store vector(actually no explicit store)
        number<1>{});

    // merge (Repeat_M, kAMLane) -> m and (Repeat_N, WarpPerBlock_N, kABKLane, KPack) -> n
    constexpr auto desc = transform_tensor_descriptor(
        lds_block_desc_0,
        make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
                   make_merge_transform(make_tuple(number<Repeat_N>{},
                                                   number<WarpPerBlock_N>{},
                                                   number<kABKLane>{},
                                                   number<KPack>{}))),
        make_tuple(sequence<0, 4>{}, sequence<1, 2, 3, 5>{}),
        make_tuple(sequence<0>{}, sequence<1>{}));

    return desc;
}
|
||||
|
||||
/// Select the warp-level gemm implementation for the first gemm (A x G),
/// dispatched at compile time on data type and warp tile shape.
/// NOTE(review): if no branch matches, the deduced return type is void and
/// callers of decltype(GetWarpGemm0<...>()) will fail with an opaque error —
/// consider a static_assert for unsupported configs.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
{
    using S_ = typename Problem::BlockShape;
    // A is vgpr, B is agpr. But since we transposed, so also need swap this
    // TODO: this is ugly
    constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
    // TODO: ugly
    if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
                 std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
                 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
    {
        // bf16, 32x32x16 warp tile -> iterate-K (2x) over 32x32x8 MFMA
        return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
            WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
            2>>{};
    }
    else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
                      std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
                      S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
    {
        // int8, 32x32x32 warp tile -> iterate-K (2x) over 32x32x16 MFMA
        return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
            WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
            2>>{};
    }
}
|
||||
|
||||
/// Instruction-slot schedule for the first gemm's software pipeline.
/// Returns a seq<...> where each element is a bitmask tagging which
/// non-mfma operation (shared-load A / global-load A / global-load B / none)
/// should be interleaved into the corresponding mfma slot, so those
/// instructions are hidden under the mfma latency.
/// NOTE(review): returns void for tile shapes with no matching branch.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0()
{
    // this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
    // the purpose is to hide those instructions under mfma
    // every value inside seq<...> is a mask, indicating a specific operation
    using S_ = typename Problem::BlockShape;
    constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
    constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
    constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
    if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
                 S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
                 S_::Block_N1 == 128)
    {
        // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
        // gld_a 8x ds_read_b128 sld_a total 64 slot :)
        // clang-format off
        constexpr auto seq_all =
               //   0      1      2      3      4      5      6      7
        sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A,  // 0
                 GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A,  // 1
                 GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A,  // 2
                 GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A,  // 3
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 4
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 5
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 6
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0>{};   // 7
        return seq_all;
        // clang-format on
    }
    else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                      std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                      S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
                      S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
                      S_::Block_N1 == 128)
    {
        // Total 32 instructions, 16 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
        // gld_a 8x ds_read_b128 sld_a -> 32 slots
        // clang-format off
        constexpr auto seq_all =
               //   0      1      2      3      4      5      6      7
        sequence<GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A,  // 0
                 GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A, GLD_B, GLD_A,  // 1
                 GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A,  // 2
                 GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A>{}; // 3
        return seq_all;
        // clang-format on
    }
}
|
||||
|
||||
/// Instruction-slot schedule for the second gemm's software pipeline.
/// Same scheme as GetSequencer_0, but the non-mfma work to hide is
/// global-load B (the D weight) and global-store O (the output).
/// NOTE(review): returns void for tile shapes with no matching branch.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1()
{
    // this function return seq<...> used to identify gld/sld/valu... inside mfma sequence
    // the purpose is to hide those instructions under mfma
    // every value inside seq<...> is a mask, indicating a specific operation
    using S_ = typename Problem::BlockShape;
    constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
    constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
    if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                 S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
                 S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
                 S_::Block_N1 == 128)
    {
        // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
        // gld_a 8x ds_read_b128 sld_a total 64 slot :)
        // clang-format off
        constexpr auto seq_all =
               //   0      1      2      3      4      5      6      7
        sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O,  // 0
                 GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O,  // 1
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 2
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 3
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 4
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 5
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 6
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0>{};   // 7
        return seq_all;
        // clang-format on
    }
    else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                      std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                      S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
                      S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
                      S_::Block_N1 == 128)
    {
        // 32-slot variant for the Block_N0 == 256 tile
        // (original comment said "64 instructions" — copy-paste from the branch above)
        // clang-format off
        constexpr auto seq_all =
               //   0      1      2      3      4      5      6      7
        sequence<GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O,  // 0
                 GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O, GLD_B, GST_O,  // 1
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0,      // 2
                 GLD_B, 0,     GLD_B, 0,     GLD_B, 0,     GLD_B, 0>{};   // 3
        return seq_all;
        // clang-format on
    }
}
|
||||
|
||||
/// Select the warp-level gemm implementation for the second gemm (Y x D),
/// dispatched at compile time on data type and warp tile shape.
/// NOTE(review): dispatch checks Warp_M0/N0/K0 (the gemm-0 warp shape), not
/// Warp_M1/N1/K1 — possibly intentional if both gemms share a warp shape,
/// but worth confirming.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
    using S_ = typename Problem::BlockShape;
    constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
    // TODO: ugly
    if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
    {
        return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
            WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
            2>>{};
    }
    else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
                      std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
                      S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
    {
        return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
            WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
            2>>{};
    }
}
|
||||
|
||||
/// Create the per-thread accumulator (C) tile for gemm-0: the outer
/// (repeat x warp) distribution over M0/N0 is embedded with the warp gemm's
/// per-lane C distribution, then a static distributed tensor of the warp
/// gemm's C data type is constructed from it.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm0()
{
    using S_        = remove_cvref_t<typename Problem::BlockShape>;
    using WarpGemm  = remove_cvref_t<decltype(GetWarpGemm0<Problem>())>;
    using CDataType = typename WarpGemm::CDataType;

    // block-level (outer) encoding: repeats and warps along M0 / N0
    constexpr auto c_block_outer_dstr_encoding =
        tile_distribution_encoding<sequence<>,
                                   tuple<sequence<S_::Repeat_M0, S_::WarpPerBlock_M0>,
                                         sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>>,
                                   tuple<sequence<1, 2>>,
                                   tuple<sequence<1, 1>>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>{};

    // combine with the warp-level C encoding to get the full block encoding
    constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
    constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
    auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
    return c_block_tensor;
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeCBlockTile_Gemm1()
|
||||
{
|
||||
using S_ = remove_cvref_t<typename Problem::BlockShape>;
|
||||
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
|
||||
using CDataType = typename WarpGemm::CDataType;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
|
||||
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// this is used as A matrix for 2nd gemm
|
||||
/// Tile distribution for the intermediate Y tensor, which acts as the A
/// operand of the second gemm: all waves replicate the same M rows
/// (WarpPerBlock_M1 is an R-dimension) while repeats cover M1 and K1, then
/// the warp gemm's A-operand lane encoding is embedded.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
{
    using S_       = remove_cvref_t<typename Problem::BlockShape>;
    using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;

    // TODO: all waves along different N, but same M (M is replicated)
    constexpr auto y_outer_dstr_enc =
        tile_distribution_encoding<sequence<S_::WarpPerBlock_M1>,
                                   tuple<sequence<S_::Repeat_M1>, sequence<S_::Repeat_K1>>,
                                   tuple<sequence<0>>,
                                   tuple<sequence<0>>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>{};

    constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
        y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
    constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
    return y_block_dstr;
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile()
|
||||
{
|
||||
constexpr auto y_block_dstr = MakeYTileDistribution<Problem>();
|
||||
auto y_block_tensor =
|
||||
make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
|
||||
return y_block_tensor;
|
||||
}
|
||||
|
||||
/// Select the hand-written flat-matmul micro-kernel for the first gemm,
/// dispatched on data type and the 32x512x128 block / 16x16x32 warp shape.
/// NOTE(review): unmatched configurations fall off the end and deduce a void
/// return type; callers should have rejected them earlier.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetUK_0()
{
    using S_ = typename Problem::BlockShape;
    if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
                 std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
                 S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
                 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
    {
        return Flatmm_32x512x128_1x4x1_16x16x32_BF16{};
    }
    else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
                      std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
                      S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
                      S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
    {
        return Flatmm_32x512x128_1x4x1_16x16x32_FP16{};
    }
}
|
||||
|
||||
/// Select the hand-written flat-matmul micro-kernel for the second gemm
/// (the "Sn" variant, which also applies the per-row topk weight scale),
/// dispatched on data types and the 32x128x512 block / 16x16x32 warp shape.
/// NOTE(review): same implicit-void fall-through as GetUK_0.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
{
    using S_ = typename Problem::BlockShape;
    if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
                 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
                 std::is_same_v<typename Problem::TopkWeightDataType, float> &&
                 S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
                 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
    {
        return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16{};
    }
    else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
                      std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
                      std::is_same_v<typename Problem::TopkWeightDataType, float> &&
                      S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
                      S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
    {
        return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
    }
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,354 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
This pipeline deal with a gemm(actually 2 gemm) with one very small(token), one very big(weight)
|
||||
we need to design the pipeline such that all waves along gemm-N dim (gemm-m only 1 wave)
|
||||
|
||||
<----- gemm-N ------>
|
||||
+----+----+----+----+
|
||||
| w0 | w1 | w2 | w3 | gemm-m
|
||||
+----+----+----+----+
|
||||
*/
|
||||
/// Fused-MoE gemm pipeline built on hand-written flat-matmul micro-kernels:
/// gemm-0 (token x gate/up weight) -> activation -> LDS "bridge" ->
/// gemm-1 (Y x down weight) -> scaled, predicated store of the output rows.
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
struct FusedMoeGemmPipeline_FlatmmUk
{
    using Problem = remove_cvref_t<Problem_>;
    using Policy  = remove_cvref_t<Policy_>;

    using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape

    using ADataType            = typename Problem::ADataType;
    using GDataType            = typename Problem::GDataType;
    using DDataType            = typename Problem::DDataType;
    using AccDataType          = typename Problem::AccDataType;
    using ODataType            = typename Problem::ODataType;
    using AScaleDataType       = typename Problem::AScaleDataType;
    using GScaleDataType       = typename Problem::GScaleDataType;
    using DScaleDataType       = typename Problem::DScaleDataType;
    using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
    using TopkWeightDataType   = typename Problem::TopkWeightDataType;
    using IndexDataType        = typename Problem::IndexDataType;
    using YDataType            = typename Problem::YDataType;

    using Traits = typename Problem::Traits;

    static constexpr bool IsGateOnly          = Traits::IsGateOnly;
    static constexpr bool UseSmoothQuant      = Traits::UseSmoothQuant;
    static constexpr bool PadHiddenSize       = Traits::PadHiddenSize;
    static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;

    // per-operand global-memory access alignments (vector widths)
    static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
    static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
    static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
    static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();

    // sequencer bitmask values (see FusedMoeGemmPipelineSequencerEnum)
    static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
    static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
    static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
    static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);

    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            // minimize occupancy
            return 2;
        }
    }();

    static constexpr const char* name = "flatmm_uk";

    /// LDS requirement: shared between gemm-0, gemm-1 and the bridge tile
    /// (they are used in disjoint phases), so take the max of the three.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
        constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
        constexpr index_t smem_bridge =
            BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
        return max(smem_0, max(smem_1, smem_bridge));
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetACoord()
    {
        constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
        const auto a_coord    = a_dist.calculate_index();
        return a_coord;
    }

    // this is the thread-offset along row/col
    CK_TILE_HOST_DEVICE static auto GetOCoord()
    {
        constexpr auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
        const auto o_coord    = o_dist.calculate_index();
        return o_coord;
    }

    /// Number of row coordinates each thread owns when loading A
    /// (rows per block / rows covered per issue).
    CK_TILE_DEVICE constexpr auto GetNumRowCoords_A()
    {
        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;

        return MRepeat;
    }

    // TODO: properlly support scatter/gather
    /// Row indices (within the sorted token list) this thread loads for A.
    CK_TILE_DEVICE auto GetRowCoords_A(index_t base_offset)
    {
        constexpr index_t KLans   = BlockShape::Block_K0 / kAlignmentA;
        constexpr index_t MLans   = BlockShape::BlockSize / KLans;
        constexpr index_t MRepeat = BlockShape::Block_M0 / MLans;

        auto base_coord = threadIdx.x / KLans + base_offset;

        array<index_t, MRepeat> coords;
        static_for<0, MRepeat, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLans; });

        return coords;
    }

    /// Gather the actual token row ids through the sorted_token_ids indirection.
    template <typename ROW_COORDS>
    CK_TILE_DEVICE auto GetRowID(const ROW_COORDS coords, const IndexDataType* sorted_token_ids_ptr)
    {
        constexpr index_t n_size = coords.size();

        array<index_t, n_size> row_ids;
        static_for<0, n_size, 1>{}([&](auto i) {
            row_ids.at(i) = sorted_token_ids_ptr[coords[i]]; // base_coord + i * MLans;
        });

        return row_ids;
    }

    /// Gather per-row topk weights used to scale gemm-1's output.
    template <typename ROW_COORDS>
    CK_TILE_DEVICE auto GetWeightScale(const ROW_COORDS coords,
                                       const TopkWeightDataType* sorted_weight_ptr)
    {
        constexpr index_t n_size = coords.size();

        array<TopkWeightDataType, n_size> w;
        static_for<0, n_size, 1>{}([&](auto i) {
            w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
        });

        return w;
    }

    // TODO: this row id is before shuffle atomic, need use acc distribution
    /// Row coordinates this thread owns when storing O.
    CK_TILE_DEVICE auto GetRowCoords_O(index_t base_offset)
    {
        constexpr index_t MLanes   = BlockShape::Warp_M1;
        constexpr index_t Repeat_M = BlockShape::Repeat_M1;

        auto base_coord = threadIdx.x % MLanes + base_offset;

        array<index_t, Repeat_M> coords;
        static_for<0, Repeat_M, 1>{}([&](auto i) { coords.at(i) = base_coord + i * MLanes; });

        return coords;
    }

    /// Run the fused pipeline for one (sorted_tile_id, intermediate_tile_id)
    /// block: gemm-0 over gate(/up) weights, activation, LDS bridge, gemm-1
    /// over down weights, weighted & predicated output store.
    template <typename Karg>
    CK_TILE_DEVICE auto operator()(const Karg& kargs,
                                   CK_TILE_LDS_ADDR void* smem,
                                   index_t sorted_tile_id,
                                   index_t intermediate_tile_id)
    {
        // gate-only: gemm-0 N == intermediate; gate+up: N is 2x intermediate
        // NOTE(review): "radio" is presumably a typo for "ratio"
        constexpr index_t hidden_radio_0            = IsGateOnly ? 1 : 2;
        ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
        ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;

        index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
        index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0;          // divide K in W
        index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
        index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;

        // expert for this tile; readfirstlane keeps it in an sgpr (uniform per wave)
        const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
            reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
        index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
        index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;

        // weights are stored permuted as nr*kr*w
        index_t interm_idx_nr0 = __builtin_amdgcn_readfirstlane(
            intermediate_tile_id *
            BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)

        index_t interm_idx_kr1 = __builtin_amdgcn_readfirstlane(
            intermediate_tile_id *
            BlockShape::Block_Kr1); // same tile, viewed as gemm-1's K

        // A (token) rows: indirect through sorted_token_ids
        auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
        auto row_ids_a    = GetRowID(
            row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
        auto a_coords = generate_tuple(
            [&](auto i) {
                return row_ids_a[i] * kargs.stride_token +
                       threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
            },
            number<row_ids_a.size()>{});
        auto a_res =
            make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
                                      kargs.num_tokens * kargs.stride_token * sizeof(ADataType));

        // G (gate/up weight) window: [nr_0, kr_0, W0] view into this expert's slab
        auto g_win = [&]() {
            const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_0 +
                                     interm_idx_nr0 * kr_0 * BlockShape::Block_W0;
            auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
                g_ptr,
                make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
                make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
                number<kAlignmentG>{},
                number<1>{});

            auto g_window_ = make_tile_window_linear_raw(
                g_view_,
                make_tuple(number<BlockShape::Block_Nr0>{},
                           number<BlockShape::Block_Kr0>{},
                           number<BlockShape::Block_W0>{}),
                {0, 0, 0},
                Policy::template MakeGlobalTileDistribution_G<Problem>(),
                sequence<0, 1, 1>{});
            return g_window_;
        }();

        auto g_res    = g_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
        auto g_coords = generate_tuple([&](auto i) { return g_win.cached_coords_[i].get_offset(); },
                                       number<decltype(g_win)::NumAccess_NonLinear>{});

        // D (down weight) window: [nr_1, kr_1, W1] view into this expert's slab
        const auto d_win = [&]() {
            const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                     static_cast<long_index_t>(expert_id) * expert_stride_1 +
                                     interm_idx_kr1 * BlockShape::Block_W1;
            // note interm_idx_nr0 is along the gemm-k dim of 2nd gemm

            const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
                d_ptr,
                make_tuple(nr_1, kr_1, BlockShape::Block_W1),
                make_tuple(kr_1 * BlockShape::Block_W1, BlockShape::Block_W1, 1),
                number<kAlignmentD>{},
                number<1>{});

            const auto d_window_ = make_tile_window_linear_raw(
                d_view_,
                make_tuple(number<BlockShape::Block_Nr1>{},
                           number<BlockShape::Block_Kr1>{},
                           number<BlockShape::Block_W1>{}),
                {0, 0, 0},
                Policy::template MakeGlobalTileDistribution_D<Problem>(),
                sequence<0, 1, 1>{});
            return d_window_;
        }();
        auto d_res = d_win.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;

        // TODO: load D order is N0.K0...127, N64.K0...127, N0.K128...255, N64.K128...255
        // block-k=512, block-n=128
        // wg                   |<----- W_ ----->|
        // Nr(2)*Nw(4)* Kr *Kr0(4)*Kr1(4) * [Kl(4)*Nl(16)*Kv(8)]->one issue
        //  y    p      y    y     p          p     y     y
        //  1    2      0(imm)
        // NOTE(review): the constants below hard-code the 32x128x512 UK layout —
        // confirm they match GetUK_1's expectations before reuse.
        auto d_coords = [&]() {
            constexpr index_t Nr_          = 2;
            constexpr index_t Nw_          = 4;
            constexpr index_t Kr0_         = 4;
            constexpr index_t Kr1_         = 4;
            constexpr index_t Kl_          = 4;
            constexpr index_t Nl_          = 16;
            constexpr index_t Kv_          = 8;
            constexpr index_t W_           = Kl_ * Nl_ * Kv_;
            constexpr index_t num_offsets_ = Nr_ * Kr0_;
            index_t base_os_               = (threadIdx.x % 64) * Kv_ + (threadIdx.x / 64) *
                                                             shared_intermediate_size_1 *
                                                             Nl_; // Kr0_ * Kr1_ * W_;
            return generate_tuple(
                [&](auto i) {
                    constexpr auto i_nr_  = number<i % Nr_>{};
                    constexpr auto i_kr0_ = number<i / Nr_>{};

                    return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
                           base_os_;
                },
                number<num_offsets_>{});
        }();

        // output coordinates reuse the A row ids (same token rows)
        auto o_coords = generate_tuple(
            [&](auto i) {
                return row_ids_a[i] * kargs.stride_token +
                       threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
            },
            number<row_ids_a.size()>{});

        // per-row exec predicates: suppress stores for padded rows beyond num_tokens
        auto o_flags =
            generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
                           number<row_ids_a.size()>{});

        // LDS window used to park Y between the two gemms
        auto bridge_sst_win = [&]() {
            constexpr auto desc_ = Policy::template MakeBridgeLdsStoreForUKDesc<Problem>();
            constexpr auto dist_ = Policy::template GetUK_0<Problem>().MakeCBlockDist();
            return make_tile_window_linear(make_tensor_view<address_space_enum::lds>(
                                               reinterpret_cast<YDataType*>(smem), desc_),
                                           desc_.get_lengths(),
                                           {0, 0},
                                           dist_);
        }();
        auto o_res =
            make_wave_buffer_resource(reinterpret_cast<const ODataType*>(kargs.o_ptr),
                                      kargs.num_tokens * kargs.stride_token * sizeof(ODataType));

        auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
        auto w_scale      = GetWeightScale(
            row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));

        // ---- gemm-0: acc_0 = A x G ----
        auto uk_0  = Policy::template GetUK_0<Problem>();
        auto acc_0 = uk_0(a_res,
                          a_coords,
                          g_res,
                          g_coords,
                          smem,
                          kargs.hidden_size,
                          BlockShape::Block_K0, // tile offset for A matrix each unroll
                          BlockShape::Block_Kr0 *
                              BlockShape::Block_W0); // tile offset for B matrix each unroll

        // activation applied pairwise (packed fp32x2) across dims 1,2 of acc_0
        sweep_tile(
            acc_0,
            [&](auto idx0, auto idx1) {
                fp32x2_t v_{acc_0(idx0), acc_0(idx1)};
                typename Problem::GateActivation{}(v_, v_);
                acc_0(idx0) = v_.x;
                acc_0(idx1) = v_.y;
            },
            sequence<1, 2>{});

        auto y_pre = cast_tile<YDataType>(acc_0);

        // make sure gemm-0's LDS use is done before reusing smem for the bridge
        block_sync_lds();

        store_tile(bridge_sst_win, y_pre);
        block_sync_lds();

        // ---- gemm-1: O += topk_weight * (Y x D), predicated per row ----
        auto uk_1 = Policy::template GetUK_1<Problem>();
        uk_1(d_res,
             d_coords,
             o_res,
             o_coords,
             o_flags,
             smem,
             kargs.hidden_size, // total n number
             w_scale,
             BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
             BlockShape::Block_N1);                               // along N
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,46 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: allow the 2 gemms to have different types
/// Compile-time problem description for the fused-MoE gemm pipeline:
/// bundles the data types of every tensor involved (activations A, gate/up
/// weight G, down weight D, accumulator, output O, the various scales, topk
/// weights, indices), the activation functor, the block shape and traits.
template <typename ADataType_,
          typename GDataType_,
          typename DDataType_,
          typename AccDataType_,
          typename ODataType_,
          typename AScaleDataType_,
          typename GScaleDataType_,
          typename DScaleDataType_,
          typename YSmoothScaleDataType_,
          typename TopkWeightDataType_,
          typename IndexDataType_, // data type for all indexing
          typename GateActivation_, // = ck_tile::element_wise::Silu,
          typename BlockShape_, // should be FusedMoeGemmShape
          typename Traits_>
struct FusedMoeGemmPipelineProblem
{
    using ADataType            = remove_cvref_t<ADataType_>;
    using GDataType            = remove_cvref_t<GDataType_>;
    using DDataType            = remove_cvref_t<DDataType_>;
    using AccDataType          = remove_cvref_t<AccDataType_>;
    using ODataType            = remove_cvref_t<ODataType_>;
    using AScaleDataType       = remove_cvref_t<AScaleDataType_>;
    using GScaleDataType       = remove_cvref_t<GScaleDataType_>;
    using DScaleDataType       = remove_cvref_t<DScaleDataType_>;
    using YSmoothScaleDataType = remove_cvref_t<YSmoothScaleDataType_>;
    using TopkWeightDataType   = remove_cvref_t<TopkWeightDataType_>;
    using IndexDataType        = remove_cvref_t<IndexDataType_>;

    // the intermediate Y (input of the 2nd gemm) uses the same type as A
    using YDataType = ADataType;

    using GateActivation = remove_cvref_t<GateActivation_>;
    using BlockShape     = remove_cvref_t<BlockShape_>;
    using Traits         = remove_cvref_t<Traits_>;
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,48 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum class FusedMoeGemmWeightPermuteEnum
|
||||
{
|
||||
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
|
||||
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
|
||||
no_permute = 0,
|
||||
b_nr_kr_kw_nw_kv = 1, // 0,1,3,4,2,5
|
||||
b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
|
||||
};
|
||||
|
||||
template <bool IsGateOnly_,
|
||||
bool UseSmoothQuant_,
|
||||
index_t OAtomic_, // 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
|
||||
FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
|
||||
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
|
||||
bool PadHiddenSize_ = false,
|
||||
bool PadIntermediateSize_ = false>
|
||||
struct FusedMoeGemmTraits
|
||||
{
|
||||
// Gate+Up or Gate only
|
||||
static constexpr bool IsGateOnly = IsGateOnly_;
|
||||
static constexpr bool UseSmoothQuant = UseSmoothQuant_;
|
||||
static constexpr index_t OAtomic = OAtomic_;
|
||||
static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
|
||||
static constexpr bool PadHiddenSize = PadHiddenSize_;
|
||||
static constexpr bool PadIntermediateSize = PadIntermediateSize_;
|
||||
};
|
||||
|
||||
// Note: this need to be a bit mask
|
||||
enum class FusedMoeGemmPipelineSequencerEnum
|
||||
{
|
||||
SLD_A = 1 << 0, // shared load a
|
||||
SLD_B = 1 << 1,
|
||||
GLD_A = 1 << 2, // global load a
|
||||
GLD_B = 1 << 3,
|
||||
SST_A = 1 << 4, // shared store a
|
||||
SST_B = 1 << 5,
|
||||
GST_O = 1 << 6, // global store out
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -10,114 +10,134 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// fp16
|
||||
using WarpGemmMfmaF16F16F32M32N32K8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K16 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
|
||||
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
|
||||
using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
|
||||
using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>;
|
||||
using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
1>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M32N32K8>>;
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplF16F16F32M16N16K16>>;
|
||||
using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16,
|
||||
WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8,
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
// bf16
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K16 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>;
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
1>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8>>;
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16>>;
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16,
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8,
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
|
||||
2>>;
|
||||
|
||||
// fp8
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_bf8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_fp8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8 =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8>>;
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8>>;
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8>>;
|
||||
using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed = WarpGemmImpl<
|
||||
WarpGemmAtrributeMfmaTransposedCDistribution<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8>>;
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
template <index_t swizzle_factor = 2>
|
||||
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>,
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
|
||||
2,
|
||||
swizzle_factor>>;
|
||||
|
||||
|
||||
@@ -25,6 +25,8 @@ struct WarpGemmAtrributeMfma
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
@@ -51,10 +53,13 @@ struct WarpGemmAtrributeMfma
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
Impl{}(c_vec, a_vec, b_vec);
|
||||
Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
@@ -85,6 +90,8 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
@@ -111,8 +118,11 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
@@ -122,10 +132,33 @@ struct WarpGemmAtrributeMfmaIterateK
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t iKIter, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
number<iKIter>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_assert(iKIter < kKIter);
|
||||
|
||||
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
//});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
@@ -168,6 +201,8 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
@@ -194,11 +229,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
// swap A and B
|
||||
Impl{}(c_vec, b_vec, a_vec);
|
||||
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
@@ -226,6 +264,8 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane>>,
|
||||
@@ -255,12 +295,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
template <bool post_nop_ = false>
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
// swap A and B
|
||||
Impl{}(c_vec, b_vec, a_vec);
|
||||
Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
@@ -291,6 +334,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
static constexpr index_t kN = Impl::kM;
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
@@ -316,9 +361,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>;
|
||||
|
||||
template <bool post_nop_ = false>
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
@@ -328,10 +376,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter]);
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t iKIter, bool post_nop_ = false>
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
number<iKIter>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_assert(iKIter < kKIter);
|
||||
// swap A and B, value and type
|
||||
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
//});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
@@ -377,6 +449,8 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kBNLane>, sequence<Impl::kABKLane, Impl::kABKPerLane * kKIter>>,
|
||||
@@ -429,8 +503,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
sequence<0, 2>>;
|
||||
#endif
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
@@ -440,10 +517,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter]);
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t iKIter, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
number<iKIter>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_assert(iKIter < kKIter);
|
||||
// swap A and B, value and type
|
||||
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
//});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
@@ -488,6 +588,8 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
|
||||
static constexpr index_t kK = Impl::kK * kKIter;
|
||||
static constexpr index_t SFactor = SFactor_; // group how many CM1 together
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
|
||||
|
||||
using AWarpDstrEncoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
|
||||
@@ -518,8 +620,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
|
||||
sequence<0, 2>>;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
@@ -529,10 +634,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter]);
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t iKIter, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
number<iKIter>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
|
||||
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
|
||||
|
||||
static_assert(iKIter < kKIter);
|
||||
|
||||
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
|
||||
Impl{}(c_vec,
|
||||
reinterpret_cast<const buf_a&>(a_vec)
|
||||
.template get_as<typename Impl::AVecType>()[iKIter],
|
||||
reinterpret_cast<const buf_b&>(b_vec)
|
||||
.template get_as<typename Impl::BVecType>()[iKIter],
|
||||
bool_constant<post_nop_>{});
|
||||
//});
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -7,12 +7,68 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// TODO: refactor warp-gemm
|
||||
// currently there is a discrepency for vav/vva if we need transpose C/D
|
||||
// e.g. if we want A:agpr, B:vgpr, we have to use vva in WGAttrEnum
|
||||
// because we swap the A/B pointer in _impl code (but not known this info here)
|
||||
enum class WGAttrCtlEnum
|
||||
{
|
||||
Default_ = 0,
|
||||
Raw_vvv = 1, // c-vgpr, a-vgpr, b-vgpr
|
||||
Raw_vaa = 2, // c-vgpr, a-agpr, b-agpr
|
||||
Raw_vav = 3, // c-vgpr, a-agpr, b-vgpr
|
||||
Raw_vva = 4, // c-vgpr, a-vgpr, b-agpr
|
||||
Raw_avv = 5, // c-agpr, a-vgpr, b-vgpr
|
||||
// raw_a_a_a = 3, // c-agpr, a-agpr, b-agpr
|
||||
};
|
||||
|
||||
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
|
||||
if constexpr(post_nop_) \
|
||||
{ \
|
||||
asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n" \
|
||||
"s_nop 3" \
|
||||
: dmod_(c_vec) \
|
||||
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
|
||||
:); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
asm volatile(mfma_ " %0, %1, %2, %3\n" \
|
||||
: dmod_(c_vec) \
|
||||
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
|
||||
:); \
|
||||
}
|
||||
|
||||
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
|
||||
if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
|
||||
{ \
|
||||
DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
|
||||
} \
|
||||
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
|
||||
{ \
|
||||
DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
|
||||
} \
|
||||
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
|
||||
{ \
|
||||
DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
|
||||
} \
|
||||
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
|
||||
{ \
|
||||
DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
|
||||
} \
|
||||
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
|
||||
{ \
|
||||
DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
|
||||
}
|
||||
|
||||
// FP16
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
{
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using CDataType = float;
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 4>;
|
||||
@@ -33,16 +89,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#else
|
||||
ignore = c_vec;
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
@@ -52,18 +115,20 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
|
||||
#else
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
{
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using CDataType = float;
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = fp16_t;
|
||||
using BDataType = fp16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<fp16_t, 4>;
|
||||
using BVecType = ext_vector_t<fp16_t, 4>;
|
||||
@@ -84,16 +149,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx9__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#else
|
||||
ignore = c_vec;
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
@@ -103,19 +175,21 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
|
||||
return bit_cast<CVecType>(
|
||||
__builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
|
||||
#else
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Bf16
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
{
|
||||
using ADataType = bf16_t;
|
||||
using BDataType = bf16_t;
|
||||
using CDataType = float;
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = bf16_t;
|
||||
using BDataType = bf16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<bf16_t, 4>;
|
||||
using BVecType = ext_vector_t<bf16_t, 4>;
|
||||
@@ -136,28 +210,35 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx94__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__)
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
#else
|
||||
ignore = c_vec;
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
@@ -181,18 +262,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
|
||||
});
|
||||
return c_vec;
|
||||
#else
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
{
|
||||
using ADataType = bf16_t;
|
||||
using BDataType = bf16_t;
|
||||
using CDataType = float;
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = bf16_t;
|
||||
using BDataType = bf16_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<bf16_t, 4>;
|
||||
using BVecType = ext_vector_t<bf16_t, 4>;
|
||||
@@ -213,28 +296,34 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
|
||||
{
|
||||
#if defined(__gfx90a__) || defined(__gfx94__)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__)
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
static_for<0, 2, 1>{}([&](auto k) {
|
||||
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
|
||||
.template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
|
||||
c_vec,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
});
|
||||
#else
|
||||
ignore = c_vec;
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
@@ -258,20 +347,21 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
|
||||
});
|
||||
return c_vec;
|
||||
#else
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// FP8
|
||||
template <typename AType_, typename BType_>
|
||||
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
{
|
||||
using ADataType = AType_;
|
||||
using BDataType = BType_;
|
||||
using CDataType = float;
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = AType_;
|
||||
using BDataType = BType_;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 8>;
|
||||
using BVecType = ext_vector_t<BDataType, 8>;
|
||||
@@ -292,38 +382,120 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CVecType& c_vec, const AVecType& a_vec, const BVecType& b_vec) const
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
|
||||
}
|
||||
}
|
||||
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
|
||||
}
|
||||
}
|
||||
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v")
|
||||
}
|
||||
}
|
||||
else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v")
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
{
|
||||
DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v")
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
|
||||
.template get_as<ADataType>()[number<k>{}]);
|
||||
float b_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
|
||||
.template get_as<BDataType>()[number<k>{}]);
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
|
||||
.template get_as<ADataType>()[number<k>{}]);
|
||||
float b_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
|
||||
.template get_as<BDataType>()[number<k>{}]);
|
||||
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
|
||||
});
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
|
||||
});
|
||||
#else
|
||||
ignore = c_vec;
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
@@ -356,20 +528,97 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
|
||||
});
|
||||
return c_vec;
|
||||
#else
|
||||
ignore = a_vec;
|
||||
ignore = b_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t>;
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t>;
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t>;
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t>;
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
|
||||
// int8
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using CDataType = int32_t;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 8>;
|
||||
using BVecType = ext_vector_t<BDataType, 8>;
|
||||
using CVecType = ext_vector_t<CDataType, 16>;
|
||||
|
||||
static constexpr index_t kM = 32;
|
||||
static constexpr index_t kN = 32;
|
||||
static constexpr index_t kK = 16;
|
||||
|
||||
static constexpr index_t kAMLane = 32;
|
||||
static constexpr index_t kBNLane = 32;
|
||||
static constexpr index_t kABKLane = 2;
|
||||
static constexpr index_t kABKPerLane = 8;
|
||||
|
||||
static constexpr index_t kCMLane = 2;
|
||||
static constexpr index_t kCNLane = 32;
|
||||
static constexpr index_t kCM0PerLane = 4;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
|
||||
else
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8(
|
||||
bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
|
||||
#elif defined(__gfx908__) || defined(__gfx90a__)
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float a_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
|
||||
.template get_as<ADataType>()[number<k>{}]);
|
||||
float b_f32 =
|
||||
type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
|
||||
.template get_as<BDataType>()[number<k>{}]);
|
||||
|
||||
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
|
||||
});
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
CVecType c_vec{0};
|
||||
operator()(c_vec, a_vec, b_vec);
|
||||
return c_vec;
|
||||
}
|
||||
};
|
||||
|
||||
#undef DISPATCH_MFMA_
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
|
||||
|
||||
// clang-format off
|
||||
// fp16
|
||||
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<half_t, half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
|
||||
// bf16
|
||||
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
|
||||
// fp8
|
||||
template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<fp8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
@@ -31,11 +31,21 @@ struct WarpGemmImpl
|
||||
using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
|
||||
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
|
||||
{
|
||||
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
|
||||
return WarpGemmAttribute_::get_num_of_access();
|
||||
}
|
||||
|
||||
template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
@@ -44,18 +54,49 @@ struct WarpGemmImpl
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
WarpGemmAttribute{}(c_vec, a_vec, b_vec);
|
||||
WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(const AWarpTensor& a, const BWarpTensor& b) const
|
||||
template <typename CTensor,
|
||||
typename ATensor,
|
||||
typename BTensor,
|
||||
index_t i_subk,
|
||||
bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CTensor& c,
|
||||
const ATensor& a,
|
||||
const BTensor& b,
|
||||
number<i_subk>,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
CWarpTensor c;
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
|
||||
const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
|
||||
auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
WarpGemmAttribute{}(c_vec, a_vec, b_vec, number<i_subk>{}, bool_constant<post_nop_>{});
|
||||
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
template <typename ATensor, typename BTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
|
||||
{
|
||||
using CTensor = CWarpTensor;
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
|
||||
CTensor c;
|
||||
|
||||
using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
|
||||
using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
|
||||
using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
|
||||
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
Reference in New Issue
Block a user