This commit is contained in:
coderfeli
2025-04-11 02:37:49 +00:00
parent 867a4e527c
commit c2cdfda718
21 changed files with 1636 additions and 190 deletions

View File

@@ -533,11 +533,6 @@ include_directories(BEFORE
${HIP_INCLUDE_DIRS}
)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")

View File

@@ -66,7 +66,6 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt

View File

@@ -333,14 +333,8 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}
size_t total_size =
(M * K * sizeof(A0DataType) + N * K * sizeof(B0DataType) + M * sizeof(D0DataType) +
N * sizeof(D1DataType) + M * N * sizeof(EDataType));
int rotate_buf_num =
ck::math::min(size_t(Repeat), ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size));
float ave_time = invoker.Run(
argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, true, rotate_buf_num});
argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, false, 1});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =

View File

@@ -21,7 +21,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --receipt 200
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
@@ -45,7 +45,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 200
)
add_custom_command(

View File

@@ -3,11 +3,11 @@
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp16" : "FmhaFwdFp16",
# "fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16"
# "fp8" : "FmhaFwdFp8",
# "fp8fp16": "FmhaFwdFp8Fp16",
# "fp8bf16": "FmhaFwdFp8Bf16"
}
BWD_DTYPE_MAP = {

View File

@@ -118,7 +118,7 @@ FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_inner_dispatch}
}}
"""
@@ -288,7 +288,7 @@ class FmhaFwdApiPool:
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
@@ -413,18 +413,17 @@ class FmhaFwdKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
# '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None
@@ -441,7 +440,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
if False:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
@@ -449,20 +448,24 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
else:
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
if receipt == 1 and bias != "bias":
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
# if bias == "bias":
# # TODO: rocm 6.2 compiler problem if using qr_async for bias case
# pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
# else:
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
# if receipt == 1 and bias != "bias":
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse/dropout kernels
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
@@ -490,10 +493,6 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't' or (pipeline.F_mask not in ['no', 's_no']):
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
@@ -534,6 +533,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_bias == 'no'
cond &= pipeline.F_lse == 'f'
cond &= pipeline.F_dropout == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait())

View File

@@ -480,6 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
// std::vector<int32_t> page_idx_host(seqstart_k_host.back(), 0);
ck_tile::HostTensor<int32_t> page_idx_host({seqstart_k_host.back()});
// std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end(), 0);
// for (int i = 0; i < page_idx_host.get_element_space_size(); i++) {
// page_idx_host(i) = (i + 19) % page_idx_host.size();
// }
page_idx_host.savetxt("page_idx_host.txt", "int");
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
@@ -585,7 +593,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
};
bool is_v_rowmajor = vlayout == std::string("r");
assert(is_v_rowmajor);
assert(!i_perm);
// host memory for storing all the tensor elements
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
const ck_tile::index_t shape_seqlen_q =
@@ -601,6 +611,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
ck_tile::HostTensor<KDataType> k_host_sgl({seqstart_k_host.back(), nhead_k, hdim_q});
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
ck_tile::HostTensor<KDataType> knew_host(
0 < seqlen_knew
@@ -613,6 +625,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
ck_tile::HostTensor<KDataType> v_host_sgl({seqstart_k_host.back(), nhead_k, hdim_v});
ck_tile::HostTensor<VDataType> vnew_host(
0 < seqlen_knew
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
@@ -742,13 +755,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
}
k_host.ForEach([&](auto& self, auto i) {
k_host_sgl(page_idx_host(i[1]), i[2], i[3]) = self(i);
});
v_host.ForEach([&](auto& self, auto i) {
v_host_sgl(page_idx_host(i[1]), i[2], i[3]) = self(i);
});
// k_host.savetxt("k_host.txt");
// k_host_sgl.savetxt("k_host_sgl.txt");
// v_host_sgl.savetxt("v_host_sgl.txt");
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host_sgl.get_element_space_size_in_bytes());
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host_sgl.get_element_space_size_in_bytes());
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
@@ -757,6 +778,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem page_idx(page_idx_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
0 <= seqlen_kpads[0]
? seqlen_ks.size() * sizeof(int32_t)
@@ -773,14 +795,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
k_buf.ToDevice(k_host_sgl.data());
knew_buf.ToDevice(knew_host.data());
v_buf.ToDevice(v_host.data());
v_buf.ToDevice(v_host_sgl.data());
vnew_buf.ToDevice(vnew_host.data());
bias_buf.ToDevice(bias_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
: seqstart_k_with_padding_host.data());
page_idx.ToDevice(page_idx_host.data());
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
? seqlen_ks.data()
: nullptr);
@@ -996,6 +1020,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
args.seqstart_k_ptr =
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
args.page_idx_ptr =
(mode == mode_enum::group ? page_idx.GetDeviceBuffer() : nullptr);
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
? seqlen_k_buf.GetDeviceBuffer()
: nullptr);
@@ -1171,7 +1197,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
o_buf.FromDevice(o_host.data()); // TODO: ugly
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
bool pass_ = ck_tile::check_err(
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
@@ -1513,6 +1538,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on
// o_host_result.savetxt("o_host_result.txt");
// o_host_ref.savetxt("o_host_ref.txt");
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
bool cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);

View File

@@ -129,7 +129,8 @@ struct fmha_fwd_args
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
const void* page_idx_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
@@ -326,6 +327,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.page_idx_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,

View File

@@ -0,0 +1,316 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
using F16 = ck_tile::half_t;
// using BF16 = ck::bhalf_t;
using F8 = ck_tile::fp8_t;
using F32 = float;
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
/// Derives the (relative, absolute) error tolerances used when comparing a GEMM
/// result against its reference.
///
/// Two pairs of thresholds are evaluated and the larger of each is returned:
///  - one for the accumulation over ceil(K / kbatch) elements, using the narrower
///    (by sizeof) of ADataType / BDataType as the compute type;
///  - one for the split-k reduction over `kbatch` partial results in CDataType.
///
/// @param K                      reduction length of the GEMM
/// @param kbatch                 number of split-k partitions
/// @param max_accumulated_value  largest value found in the reference output
/// @return ck_tile tuple {rtol, atol}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;

    // Number of elements each split accumulates.
    const auto k_per_split = ck_tile::integer_divide_ceil(K, kbatch);

    // Thresholds for the per-split accumulation.
    const auto rtol_accum =
        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(k_per_split);
    const auto atol_accum = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, k_per_split);

    // Thresholds for the error introduced by summing the split-k partial results.
    const auto rtol_reduce =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_reduce = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);

    // The looser of the two bounds wins for each tolerance.
    const auto rtol = std::max(rtol_accum, rtol_reduce);
    const auto atol = std::max(atol_accum, atol_reduce);
    return ck_tile::make_tuple(rtol, atol);
}
/// Launches the MoE GEMM kernel through moe_gemm<> and prints a one-line
/// performance summary (average time, TFlops, GB/s) to stdout.
///
/// @param n_warmup  warm-up iteration count forwarded to the stream config
/// @param n_repeat  repeat iteration count forwarded to the stream config
/// @param args      kernel arguments; M, N and K are read for the perf estimate
/// @return the average time reported by moe_gemm (the summary prints it as ms)
template <typename ALayout, typename BLayout, typename CLayout>
float invoke_gemm(int n_warmup, int n_repeat, const moe_gemm_kargs& args)
{
    const float ave_time = moe_gemm<ALayout, BLayout, CLayout>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});

    // 2 * M * N * K multiply-add operations for a dense GEMM.
    const std::size_t flop = std::size_t(2) * args.M * args.N * args.K;

    // Bytes moved: read A (M x K), read B (K x N), write C (M x N).
    const std::size_t num_btype = sizeof(ADataType) * args.M * args.K +
                                  sizeof(BDataType) * args.K * args.N +
                                  sizeof(CDataType) * args.M * args.N;

    const float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    const float gb_per_sec = num_btype / 1.E6 / ave_time;

    const std::string op_name{"Moe Gemm"};
    std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s, " << op_name << std::endl;

    return ave_time;
}
/**
 * @brief Runs the MoE GEMM example for one (A, B, C) layout combination.
 *
 * Parses the command line, builds host tensors for A, B and C, fills a hard-coded
 * token/expert routing table, launches the kernel through invoke_gemm() and, when the
 * "validate" option is non-zero, compares the result against
 * ck_tile::reference_moe_gemm_gpu.
 *
 * @param argc      argument count forwarded to create_args()
 * @param argv      argument vector forwarded to create_args()
 * @param a_layout  layout tag of A (only its type is used)
 * @param b_layout  layout tag of B (only its type is used)
 * @param c_layout  layout tag of C (unused; CLayout{} is used instead)
 * @return -1 when argument parsing fails; otherwise the validation flag converted to
 *         int (1 when validation passed or was not requested, 0 on mismatch).
 */
template <typename ALayout, typename BLayout, typename CLayout>
int run_moe_gemm_example_with_layouts(int argc,
                                      char* argv[],
                                      const ALayout a_layout                  = ALayout{},
                                      const BLayout b_layout                  = BLayout{},
                                      [[maybe_unused]] const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
    {
        return -1;
    }

    // auto valid_input_data = [&](int group_count, const auto&... args) {
    // return !(args.empty() || ...) && group_count == (args.size() == ...);
    // };
    // ck_tile::index_t N = 4096;
    // ck_tile::index_t K = 4096;
    // ck_tile::index_t experts = 8;
    // ck_tile::index_t sorted_tile_num = 8;
    // ck_tile::index_t valid_tile_num = 8;
    // ck_tile::index_t tokens = 128;
    // ck_tile::index_t topk = 2;
    const ck_tile::index_t N          = arg_parser.get_int("N");
    const ck_tile::index_t K          = arg_parser.get_int("K");
    ck_tile::index_t stride_A         = arg_parser.get_int("stride_A");
    ck_tile::index_t stride_B         = arg_parser.get_int("stride_B");
    ck_tile::index_t stride_C         = arg_parser.get_int("stride_C");
    const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
    const ck_tile::index_t topk       = arg_parser.get_int("TopK");
    const ck_tile::index_t repeat     = arg_parser.get_int("repeat");
    const ck_tile::index_t experts    = arg_parser.get_int("experts");

    // TODO: replace the magic declaration
    const ck_tile::index_t MPerBlock = 128;
    ck_tile::index_t sorted_tile_num = 8;
    ck_tile::index_t valid_tile_num  = sorted_tile_num;
    const ck_tile::index_t M         = sorted_tile_num * MPerBlock;

    std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf;
    std::unique_ptr<ck_tile::DeviceMem> b_k_n_dev_buf;
    std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf;

    // A is an M x K tensor (see its descriptor below), so its default stride has to be
    // derived from (M, K) rather than (M, N).
    stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
    stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));

    auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(
        ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
    // TODO: add the experts' weights in b
    auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
        is_row_major(b_layout)
            ? ck_tile::host_tensor_descriptor(experts * K, N, stride_B, is_row_major(b_layout))
            : ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
    auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(
        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

    std::cout << "gemm"
              << " a_m_k: " << a_m_k_tensor.mDesc << " b_k_n: " << b_k_n_tensor.mDesc
              << " c_m_n: " << c_m_n_tensor.mDesc << std::endl;

    ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
    ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensor);

    a_m_k_dev_buf =
        std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes());
    b_k_n_dev_buf =
        std::make_unique<ck_tile::DeviceMem>(b_k_n_tensor.get_element_space_size_in_bytes());
    c_m_n_dev_buf =
        std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());

    a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
    b_k_n_dev_buf->ToDevice(b_k_n_tensor.data());
    c_m_n_dev_buf->SetZero();
    c_m_n_tensor.SetZero();

    const void* p_a = a_m_k_dev_buf->GetDeviceBuffer();
    const void* p_b = b_k_n_dev_buf->GetDeviceBuffer();
    void* p_c       = c_m_n_dev_buf->GetDeviceBuffer();

    // TODO: malloc and init sorted tokens and max tokens buffer
    ck_tile::HostTensor<ck_tile::index_t> expert_ids(
        ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
    ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
        ck_tile::HostTensorDescriptor({sorted_tile_num * MPerBlock}, {1}));
    ck_tile::HostTensor<ck_tile::index_t> max_token_id(
        ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));

    // get_element_space_size_in_bytes() is already a byte count (the A/B/C buffers above
    // use it directly), so it must not be scaled by sizeof(index_t) a second time.
    // NOTE(review): if DeviceMem::ToDevice(ptr) copies the whole device allocation, the
    // previous oversized buffers also read past the end of the host tensors - confirm.
    std::unique_ptr<ck_tile::DeviceMem> sorted_token_ids_dev =
        std::make_unique<ck_tile::DeviceMem>(sorted_token_ids.get_element_space_size_in_bytes());
    std::unique_ptr<ck_tile::DeviceMem> expert_ids_dev =
        std::make_unique<ck_tile::DeviceMem>(expert_ids.get_element_space_size_in_bytes());
    std::unique_ptr<ck_tile::DeviceMem> max_token_id_dev =
        std::make_unique<ck_tile::DeviceMem>(max_token_id.get_element_space_size_in_bytes());

    // NOTE(review): this list holds 10 values while the descriptor above declares
    // 1 + sorted_tile_num = 9 elements - confirm which of the two is intended.
    max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
    int eids[]         = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
    for(int i = 0; i < sorted_tile_num; i++)
    {
        expert_ids.mData[i] = eids[i];
    }

    // Spread num_tokens * topk (token, slot) pairs evenly over the valid tiles.
    int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
    int tokenid        = 0;
    // sorted_token_ids.mData[0] = 0;
    for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
    {
        int tile_off = i % MPerBlock;
        if(tile_off < token_per_tile && tokenid < num_tokens * topk)
        {
            // Low bits carry (tokenid % num_tokens); (tokenid / num_tokens) goes to bits 24+.
            sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
            tokenid++;
        }
        else
        {
            // Remaining slots of the tile are filled with the last token id.
            sorted_token_ids.mData[i] = num_tokens - 1;
        }
    }

    sorted_token_ids_dev->ToDevice(sorted_token_ids.data());
    expert_ids_dev->ToDevice(expert_ids.data());
    max_token_id_dev->ToDevice(max_token_id.data());

    const ck_tile::index_t* p_sorted_token_ids_dev =
        static_cast<ck_tile::index_t*>(sorted_token_ids_dev->GetDeviceBuffer());
    const ck_tile::index_t* p_expert_ids_dev =
        static_cast<ck_tile::index_t*>(expert_ids_dev->GetDeviceBuffer());
    const ck_tile::index_t* p_max_token_id_dev =
        static_cast<ck_tile::index_t*>(max_token_id_dev->GetDeviceBuffer());

    moe_gemm_kargs gemm_desc{p_sorted_token_ids_dev,
                             p_expert_ids_dev,
                             p_max_token_id_dev,
                             p_a,
                             p_b,
                             p_c,
                             num_tokens,
                             topk,
                             M,
                             N,
                             K,
                             stride_A,
                             stride_B,
                             stride_C};

    invoke_gemm<ALayout, BLayout, CLayout>(3, repeat, gemm_desc);

    c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());

    bool pass{true};
    if(arg_parser.get_int("validate"))
    {
        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
        c_m_n_host_ref.SetZero();

        std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
            std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
        c_m_n_ref_buf->SetZero();

        ck_tile::reference_moe_gemm_gpu<ADataType,
                                        BDataType,
                                        AccDataType,
                                        CDataType,
                                        ALayout,
                                        BLayout,
                                        CLayout>(
            p_sorted_token_ids_dev,
            p_expert_ids_dev,
            p_max_token_id_dev,
            static_cast<const ADataType*>(p_a),
            static_cast<const BDataType*>(p_b),
            static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
            num_tokens,
            topk,
            M,
            N,
            K,
            stride_A,
            stride_B,
            stride_C);

        // Bring the reference result to the host BEFORE deriving the tolerances from it;
        // reading the maximum first would always see the zero-initialised tensor.
        c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());

        const float max_accumulated_value =
            *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
            K, 1 /*kbatch*/, max_accumulated_value);

        // for(int im = 0; im < M; im++)
        // {
        // for(int in = 0; in < N; in++)
        // {
        // // if (static_cast<float>(static_cast<CDataType*>(p_c)[im * N + in]) != 0)
        // printf("c[%d][%d]: %f ",
        // im,
        // in,
        // static_cast<float>(static_cast<CDataType*>(p_c)[im * N + in]));
        // printf("ref[%d][%d]: %f \n",
        // im,
        // in,
        // static_cast<float>(
        // static_cast<CDataType*>(c_m_n_host_ref.data())[im * N + in]));
        // }
        // }

        pass = ck_tile::check_err(c_m_n_tensor,
                                  c_m_n_host_ref,
                                  "Error: Incorrect results!",
                                  rtol_atol.at(ck_tile::number<0>{}),
                                  rtol_atol.at(ck_tile::number<1>{}));

        std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
                  << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
                  << std::endl;
        std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
    }

    return pass;
}
/// Entry point of the MoE GEMM example: reads the requested A/B layouts from the
/// command line and dispatches to the matching layout instantiation.
///
/// Only A = "R" with B = "C" is dispatched (C is row-major in that case); any other
/// combination throws std::runtime_error.
///
/// @return -1 when argument parsing fails, otherwise the value returned by
///         run_moe_gemm_example_with_layouts().
int run_moe_gemm_example(int argc, char* argv[])
{
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
    {
        return -1;
    }

    const std::string a_layout = arg_parser.get_str("a_layout");
    const std::string b_layout = arg_parser.get_str("b_layout");

    // Guard clause: reject everything except the single supported R/C combination.
    const bool is_row_col = (a_layout == "R") && (b_layout == "C");
    if(!is_row_col)
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }

    return run_moe_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
}

View File

@@ -18,32 +18,10 @@
namespace ck_tile {
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
template <typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
@@ -51,35 +29,11 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -104,6 +104,31 @@ CK_TILE_DEVICE void store_tile(
tile_window.store(dstr_tensor, number<-1>{});
}
// template <typename T,
// typename BottomTensorView_,
// typename WindowLengths_,
// typename TileDistribution_,
// typename DataType_>
// CK_TILE_DEVICE void
// store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
// const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
// const T& offsets)
// {
// using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
// using TileDstr = remove_cvref_t<TileDistribution_>;
// static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
// constexpr auto tile_dstr = TileDstr{};
// auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
// tile_window_tmp.get_window_lengths(),
// tile_window_tmp.get_window_origin(),
// tile_dstr);
// tile_window.store(dstr_tensor, offsets);
// }
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,

View File

@@ -0,0 +1,654 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
/**
* @brief This class provides tile (windowed) view and access to the device memory.
*
* @note This tile window does not support single-issue access; use the tile_window_linear
* structure for that purpose.
*
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
* @tparam NumCoord TBD
*/
template <typename BottomTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1,
index_t YsGatherDim = 0>
struct tile_scatter_gather
{
using BottomTensorView = remove_reference_t<BottomTensorView_>;
using WindowLengths = remove_cvref_t<WindowLengths_>;
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
using PageIdxArray = remove_cvref_t<StaticPageIndexArray_>;
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static_assert(NumCoord == 1);
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
"wrong! lengths should be static");
static_assert(TileDstr::is_static(), "wrong!");
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
"wrong! inconsistent # of diemsnions");
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
using WindowAdaptorCoord =
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
using BottomTensorCoord =
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
// Compile-time traits describing how one thread walks its Y space in load()/store():
// which Y dimension is vectorized, how many scalars one access moves, and the
// space-filling curve that enumerates the accesses.
struct load_store_traits
{
private:
// Scans all Y dimensions and selects the one with stride 1 and the largest safe vector
// length. Returns (dimension index, vector length); stays at (0, 1) if none qualifies.
static constexpr auto get_vector_dim_y_scalar_per_vector()
{
const auto [ys_vector_lengths, ys_vector_strides] =
tile_scatter_gather::
get_window_adaptor_ys_safe_vector_length_strides();
index_t VectorDimY_ = 0;
index_t ScalarPerVector_ = 1;
for(index_t i = 0; i < NDimY; ++i)
{
// strict '>' keeps the first dimension among equally long candidates
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
{
ScalarPerVector_ = ys_vector_lengths[i];
VectorDimY_ = i;
}
}
return make_tuple(VectorDimY_, ScalarPerVector_);
}
public:
// PackedSize as reported by numeric_traits for DataType
static constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
// Y dimension along which accesses are vectorized
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
// number of scalars moved by one access along VectorDimY
static constexpr index_t ScalarPerVector =
get_vector_dim_y_scalar_per_vector().template at<1>();
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type;
// per-access buffer holding ScalarPerVector / PackedSize elements of DataType
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
private:
// Sequence of per-dimension access sizes: ScalarPerVector on VectorDimY, 1 elsewhere.
static constexpr auto scalars_per_access_ = [] {
constexpr auto scalars_per_access_arr = generate_array(
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr auto NDimY_ = NDimY;
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
}();
// Builds the space-filling curve over the thread's Y lengths, visiting dimensions in
// natural order (0, 1, ..., NDimY-1) with scalars_per_access_ as the access shape.
static constexpr auto get_space_filling_curve()
{
constexpr auto tile_dstr = TileDstr{};
constexpr auto thread_tensor_lengths_ys =
to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
// FIXME: need logic to judge dim access order
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
return space_filling_curve<decltype(thread_tensor_lengths_ys),
DimAccessOrder,
decltype(scalars_per_access_)>{};
}
public:
// curve type and the total number of accesses one thread performs per load/store
using SFC_Ys = decltype(get_space_filling_curve());
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
};
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
// Defaulted default constructor.
CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
// Constructs the window from its tensor view, lengths, origin, tile distribution and
// page-index array, then pre-computes NumCoord (window-adaptor, bottom-tensor) coordinate
// pairs for the calling thread so that load()/store() can start from them.
CK_TILE_DEVICE constexpr tile_scatter_gather(
const BottomTensorView& bottom_tensor_view,
const WindowLengths& window_lengths,
const BottomTensorIndex& window_origin,
const TileDstr& tile_distribution,
const PageIdxArray& page_idx)
: bottom_tensor_view_{bottom_tensor_view},
window_lengths_{window_lengths},
window_origin_{window_origin},
tile_dstr_{tile_distribution},
page_idx_{page_idx},
pre_computed_coords_{}
{
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this use less register for FA, but more register for GEMM
// need investigation
// adaptor coordinate at [partition index..., 0, 0, ...] (all Y components zero)
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(detail::get_partition_index(tile_distribution),
array<index_t, NDimY>{0}));
#endif
// thread origin in the bottom tensor = window origin + thread offset inside the window
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
// the gather dimension is zeroed here; load()/store() pass an entry of page_idx_
// alongside the coordinate on every access
bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
// tuple<index_t, index_t>(0, window_adaptor_thread_coord_tmp.get_bottom_index()[1]);
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using Traits = load_store_traits;
using SFC_Ys = typename Traits::SFC_Ys;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
// Y-space step from access 0 to the first access owned by bundle iCoord
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
// prepend zeros for the P dimensions: the step only moves along Ys
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
// Number of dimensions of the bottom tensor descriptor.
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
// Whether the tile distribution type is static (forwards TileDstr::is_static()).
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
{
return TileDstr::is_static();
}
// Plain getters; each returns a copy of the corresponding member.
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
// Re-points the bottom tensor view's buffer at `data`; no other member is updated.
CK_TILE_DEVICE constexpr void
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
{
bottom_tensor_view_.buf_.p_data_ = data;
}
// Advance a thread's window-adaptor coordinate by a step given in adaptor-top space
// ([p0, p1, ..., y0, y1, ...]) and apply the resulting bottom-space step to the
// bottom-tensor coordinate:
//   [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template <typename ATopIndex>
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
    WindowAdaptorCoord& window_adaptor_thread_coord,
    BottomTensorCoord& bottom_tensor_thread_coord,
    const ATopIndex& idx_diff_adaptor_top) const
{
    // filled in by the adaptor move with the induced step in bottom-index space
    BottomTensorIndex bottom_step;

    move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
                                   window_adaptor_thread_coord,
                                   idx_diff_adaptor_top,
                                   bottom_step);

    // replay that step on the tensor coordinate
    move_tensor_coordinate(
        bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord, bottom_step);
}
// Returns a tuple (lengths, strides) holding, for every Y dimension [y0, y1, ...], the safe
// vector length and the vector stride obtained by propagating the bottom tensor's
// top-dimension values through the window adaptor.
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
{
    using HiddenDimArray = array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>;

    // vector lengths/strides of the bottom tensor's top dimensions; these are exactly the
    // values seen at the window adaptor's bottom dimensions
    const auto [bottom_lengths, bottom_strides] =
        BottomTensorDesc::get_top_dimension_safe_vector_length_strides();

    // per-hidden-dimension tables for the adaptor [p0, p1, ..., y0, y1, ...]
    HiddenDimArray hidden_lengths{-1};
    HiddenDimArray hidden_strides{-1};

    // seed the entries that correspond to the adaptor's bottom dimensions
    constexpr auto bottom_hidden_ids = WindowAdaptor::get_bottom_dimension_hidden_ids();
    set_container_subset(hidden_lengths, bottom_hidden_ids, bottom_lengths);
    set_container_subset(hidden_strides, bottom_hidden_ids, bottom_strides);

    // propagate up to the adaptor's top dimensions (Ps followed by Ys)
    const auto [ps_ys_lengths, ps_ys_strides] =
        WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(hidden_lengths,
                                                                     hidden_strides);

    // keep only the Y part: top dimensions [NDimP, NDimWindowAdaptorTop)
    constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
                                                             NDimWindowAdaptorTop,
                                                             1>::type{};

    return make_tuple(get_container_subset(ps_ys_lengths, y_dims),
                      get_container_subset(ps_ys_strides, y_dims));
}
// Number of accesses one thread issues per load()/store() (load_store_traits::NumAccess).
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
// Convenience overload: creates a distributed tensor for this window's tile distribution,
// fills it through the load(dst_tensor, ...) overload and returns it by value.
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
                         bool_constant<oob_conditional_check> = {}) const
{
    constexpr auto distribution = TileDstr{};

    auto loaded = make_static_distributed_tensor<DataType>(distribution);

    load(loaded, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});

    return loaded;
}
// Loads this thread's part of the window into `dst_tensor`.
// For every access of the space-filling curve the page offset is looked up in page_idx_
// using the access's coordinate along YsGatherDim, a vector is read from the bottom tensor
// view at (thread coordinate, page offset), and its elements are written to the thread
// buffer of `dst_tensor`. The function has no return statement (deduced return type void).
template <typename DistributedTensor,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
// work on copies so the cached coordinates stay untouched (this method is const)
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// coordinate along the gather dimension selects the page_idx_ entry
constexpr auto idx_m = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_m];
// read from bottom tensor
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, page_offset, bool_constant<oob_conditional_check>{});
#if 1
// write into distributed tensor
// j advances by PackedSize scalars, i.e. one DataType element of the vector at a time
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
// thread-buffer slot for this Y index, in units of DataType elements
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
});
#else
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
// zero the step along the gather dimension: movement there is expressed
// through page_offset, not through the tensor coordinate
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
/**
 * @brief Store this thread's part of a distributed tensor into the bottom tensor view.
 *
 * For every access of the space-filling curve, the elements of one vector are collected
 * from the thread buffer of @p dstr_tensor and written with set_vectorized_elements() at
 * (thread coordinate, page offset). The page offset is the entry of page_idx_ selected by
 * the access's coordinate along YsGatherDim. This matches load() and the forward step
 * below, which suppresses movement along YsGatherDim.
 *
 * @tparam i_access_unsupport_    Not used by this implementation.
 * @tparam oob_conditional_check  Forwarded to set_vectorized_elements().
 * @param dstr_tensor             Source distributed tensor (same DataType / TileDstr).
 */
template <index_t i_access_unsupport_ = -1,
          bool oob_conditional_check  = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                          number<i_access_unsupport_>          = {},
                          bool_constant<oob_conditional_check> = {}) const
{
    using Traits   = load_store_traits;
    using vector_t = typename Traits::vector_t;
    using SFC_Ys   = typename Traits::SFC_Ys;

    constexpr auto tile_dstr = TileDstr{};

    // loop over thread tensor space [y0, y1, ...]
    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        // work on copies so the cached coordinates stay untouched (this method is const)
        auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
        auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

        static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

            // The page index must be taken along the gather dimension (not a hard-coded
            // dimension 0), otherwise it disagrees with the suppressed forward step
            // below whenever YsGatherDim != 0.
            constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
            const auto page_offset    = page_idx_[idx_gather];

            // read from distributed tensor, one DataType element (PackedSize scalars)
            // per iteration
            vector_t vec_value;

            static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                        : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                // thread-buffer slot for this Y index, in units of DataType elements
                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
                    Traits::PackedSize;

                vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
                bottom_tensor_thread_coord,
                page_offset,
                vec_value,
                bool_constant<oob_conditional_check>{});

            // move thread coordinate
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                // no coordinate movement along the gather dimension: it is expressed
                // through page_offset instead
                constexpr auto forward_step_scatter = generate_tuple(
                    [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
                    number<NDimY>{});

                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                    forward_step_scatter);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
            }
        });
    });
}
// Move the window by `step` (given in bottom-tensor index space):
//   - window_origin_ is advanced by the full step
//   - every cached bottom-tensor coordinate is advanced by the step with its
//     HsGatherDim component forced to 0
//   [x0', x1', ... ] ==> [offset]
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
{
    window_origin_ += step;

    // step applied to the cached coordinates: identical except along the gather dimension
    auto coord_step         = step;
    coord_step(HsGatherDim) = 0;

    static_for<0, NumCoord, 1>{}([&](auto iCoord) {
        auto& cached_tensor_coord = pre_computed_coords_(iCoord)(I1);
        move_tensor_coordinate(
            bottom_tensor_view_.get_tensor_descriptor(), cached_tensor_coord, coord_step);
    });
}
// Replaces the stored page-index array with `new_idx`; subsequent load()/store() calls
// read their page offsets from the new array. Cached coordinates are not modified.
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx)
{
page_idx_ = new_idx;
// static_for<0, 2, 1>{}([&](auto k0) {
// printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]);
// });
}
// CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
// {
// window_origin_ = new_window_origin;
// #if 0 // debug
// // TODO: this use more register for FA, but less register for GEMM
// // need investigation
// // only support warp-tile and block-tile
// static_assert(NDimP == 1 or NDimP == 2, "wrong!");
// WindowAdaptorCoord window_adaptor_thread_coord_tmp;
// if constexpr(NDimP == 1)
// {
// window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
// tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
// }
// else if constexpr(NDimP == 2)
// {
// window_adaptor_thread_coord_tmp =
// make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
// AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
// }
// #else
// // TODO: this use less register for FA, but more register for GEMM
// // need investigation
// const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
// tile_dstr_.get_ps_ys_to_xs_adaptor(),
// container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
// #endif
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
// window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
// const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
// bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// // future load/store() calls (might allocate more registers)
// using Traits = load_store_traits;
// using SFC_Ys = typename Traits::SFC_Ys;
// static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
// auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
// constexpr auto idx_diff_ys =
// SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
// constexpr auto idx_diff_ps_ys = container_concat(
// generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
// move_window_adaptor_and_bottom_tensor_thread_coordinate(
// window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
// pre_computed_coords_(iCoord) =
// make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
// });
// }
// Forwards to init_raw() of the bottom tensor view.
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView bottom_tensor_view_;
//
WindowLengths window_lengths_;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex window_origin_;
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr tile_dstr_;
PageIdxArray page_idx_;
// this contains:
// per-thread coordinate for window adaptor
// per-thread coordinate for bottom tensor
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
};
// TODO: use strategy
// Factory for tile_scatter_gather: deduces the window type from the arguments (after
// stripping cv/ref qualifiers) and constructs it from the tensor view, window lengths,
// origin, tile distribution and page-index array. HsGatherDim / NumCoord are taken from
// the trailing number<> tags; YsGatherDim keeps the struct's default.
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename StaticPageIndexArray_,
          index_t HsGatherDim = 0,
          index_t NumCoord    = 1>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const TensorView_& tensor_view,
                         const WindowLengths_& window_lengths,
                         const multi_index<TensorView_::get_num_of_dimension()>& origin,
                         const StaticTileDistribution_& tile_distribution,
                         const StaticPageIndexArray_& page_idx,
                         number<HsGatherDim> = {},
                         number<NumCoord>    = {})
{
    using window_type = tile_scatter_gather<remove_cvref_t<TensorView_>,
                                            remove_cvref_t<WindowLengths_>,
                                            remove_cvref_t<StaticTileDistribution_>,
                                            remove_cvref_t<StaticPageIndexArray_>,
                                            HsGatherDim,
                                            NumCoord>;

    return window_type{tensor_view, window_lengths, origin, tile_distribution, page_idx};
}
// Builds a tile_scatter_gather from an existing tile_window_with_static_lengths, reusing
// its tensor view and window lengths but placing the window at an explicit `origin`.
//
// HsGatherDim is given a default of 0: a defaulted function argument (`number<HsGatherDim>
// = {}`) does not take part in template argument deduction, so without the default the
// trailing tag could never actually be omitted.
template <typename TensorView,
          typename WindowLengths,
          typename StaticTileDistribution,
          typename StaticPageIndexArray,
          index_t HsGatherDim = 0>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
                         const multi_index<TensorView::get_num_of_dimension()>& origin,
                         const StaticTileDistribution& tile_distribution,
                         const StaticPageIndexArray& page_idx,
                         number<HsGatherDim> = {})
{
    return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
                                    tile_window.get_window_lengths(),
                                    origin,
                                    tile_distribution,
                                    page_idx,
                                    number<HsGatherDim>{});
}
// Builds a tile_scatter_gather from an existing tile_window_with_static_lengths, reusing
// its tensor view, window lengths and window origin.
//
// HsGatherDim is given a default of 0: a defaulted function argument (`number<HsGatherDim>
// = {}`) does not take part in template argument deduction, so without the default the
// trailing tag could never actually be omitted.
template <typename TensorView,
          typename WindowLengths,
          typename StaticTileDistribution,
          typename StaticPageIndexArray,
          index_t HsGatherDim = 0>
CK_TILE_DEVICE constexpr auto
make_tile_scatter_gather(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
                         const StaticTileDistribution& tile_distribution,
                         const StaticPageIndexArray& page_idx,
                         number<HsGatherDim> = {})
{
    return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
                                    tile_window.get_window_lengths(),
                                    tile_window.get_window_origin(),
                                    tile_distribution,
                                    page_idx,
                                    number<HsGatherDim>{});
}
// template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
// CK_TILE_DEVICE constexpr auto
// make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
// const StaticTileDistribution& tile_distribution)
// {
// auto w = make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
// tile_window.get_window_lengths(),
// tile_window.get_window_origin(),
// tile_distribution);
// w.init_raw();
// return w;
// }
} // namespace ck_tile

View File

@@ -609,6 +609,93 @@ struct tile_window_with_static_distribution
});
}
// template <typename statically_indexed_array,
// index_t i_access_unsupport_ = -1,
// bool oob_conditional_check = true>
// CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
// const statically_indexed_array& offsets,
// number<i_access_unsupport_> = {},
// bool_constant<oob_conditional_check> = {}) const
// {
// using Traits = load_store_traits;
// // using vector_type_t = typename Traits::vector_type_t;
// using vector_t = typename Traits::vector_t;
// using SFC_Ys = typename Traits::SFC_Ys;
// constexpr auto tile_dstr = TileDstr{};
// // loop over thread tensor space [y0, y1, ...]
// static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
// // auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
// window_origin_ +
// tuple<index_t, index_t>(0, window_adaptor_thread_coord.get_bottom_index()[1]);
// auto bottom_tensor_thread_coord = make_tensor_coordinate(
// bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
// static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
// constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// // data index [y0, y1, ...]
// constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// constexpr auto idx_m = idx_ys_start[number<0>{}];
// const auto offset = offsets[idx_m];
// // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
// // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
// // read from distributed tensor
// // vector_type_t vec;
// vector_t vec_value;
// static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
// constexpr auto idx_ys = generate_tuple(
// [&](auto jj) {
// return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
// : idx_ys_start[jj];
// },
// number<NDimY>{});
// constexpr index_t d =
// tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
// Traits::PackedSize;
// // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
// vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
// dstr_tensor.get_thread_buffer().template at<d>();
// });
// // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// // write into bottom tensor
// get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
// bottom_tensor_thread_coord,
// offset,
// vec_value,
// bool_constant<oob_conditional_check>{});
// // printf("coord_offset:%d, scatter_offset:%d \n",
// // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
// if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
// {
// constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
// constexpr auto forward_step_scatter = generate_tuple(
// [&](auto i) { return i == 0 ? 0 : idx_diff_ys[i]; }, number<NDimY>{});
// constexpr auto idx_diff_ps_ys = container_concat(
// generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
// forward_step_scatter);
// move_window_adaptor_and_bottom_tensor_thread_coordinate(
// window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
// }
// });
// });
// }
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
@@ -1010,23 +1097,6 @@ make_tile_window_raw(const TensorView_& tensor_view,
return w;
}
// Free-function form of window.move(step) for tile_window_with_static_distribution.
// `step` is expressed in the window's BottomTensorIndex type.
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord>
CK_TILE_DEVICE void move_tile_window(
tile_window_with_static_distribution<TensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>& window,
const typename tile_window_with_static_distribution<TensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>::BottomTensorIndex& step)
{
window.move(step);
}
/**
* @brief This class provides description of tile windowed view on the device memory.
*
@@ -1155,13 +1225,4 @@ make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLen
return w;
}
// Free-function form of window.move(step) for tile_window_with_static_lengths.
// `step` is expressed in the window's BottomTensorIndex type.
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window(
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
const typename tile_window_with_static_lengths<TensorView_, WindowLengths_>::BottomTensorIndex&
step)
{
window.move(step);
}
} // namespace ck_tile

View File

@@ -1200,19 +1200,4 @@ make_tile_window_linear_raw(const TileWindow_& tile_window,
LinearBottomDims_{});
}
// Free-function form of window.move(step) for tile_window_linear.
// `step` is expressed in the window's BottomTensorIndex type.
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename LinearBottomDims_>
CK_TILE_DEVICE void move_tile_window(
tile_window_linear<TensorView_, WindowLengths_, StaticTileDistribution_, LinearBottomDims_>&
window,
const typename tile_window_linear<TensorView_,
WindowLengths_,
StaticTileDistribution_,
LinearBottomDims_>::BottomTensorIndex& step)
{
window.move(step);
}
} // namespace ck_tile

View File

@@ -18,6 +18,14 @@
#pragma once
namespace ck_tile {
// Generic free-function form of window.move(step).
// Accepts any window type that exposes a nested BottomTensorIndex type and a move()
// member taking it. TileWindow_ is deduced from `window` only; `step` is a non-deduced
// parameter because its type is spelled through TileWindow_.
template <typename TileWindow_>
CK_TILE_DEVICE void move_tile_window(
TileWindow_& window,
const typename TileWindow_::BottomTensorIndex& step)
{
window.move(step);
}
// takes an LDS store tile as input and extracts some information from it;
// used to set the m0 value for the gfx9 series
template <typename LdsTileWindow_>

View File

@@ -0,0 +1,236 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
// template <typename ADataType,
// typename BDataType,
// typename AccDataType,
// typename CDataType,
// typename AElementOp = ck_tile::identity,
// typename BElementOp = ck_tile::identity,
// typename ACCElementOp = ck_tile::identity>
// CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
// const HostTensor<BDataType>& b_k_n,
// HostTensor<CDataType>& c_m_n,
// const AElementOp& a_element_op = {},
// const BElementOp& b_element_op = {},
// const ACCElementOp& acc_element_op = {})
// {
// const std::size_t M = a_m_k.get_length(0);
// const std::size_t N = b_k_n.get_length(1);
// const std::size_t K = a_m_k.get_length(1);
// auto f_mn = [&](auto m, auto n) {
// AccDataType v_acc = 0;
// for(std::size_t k = 0; k < K; ++k)
// {
// AccDataType v_a;
// AccDataType v_b;
// if constexpr(std::is_same_v<ADataType, pk_int4_t>)
// {
// const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
// if(k % 2 == 1)
// v_a = fp32_val.hi;
// else
// v_a = fp32_val.lo;
// }
// else
// {
// v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
// }
// if constexpr(std::is_same_v<BDataType, pk_int4_t>)
// {
// const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
// if(k % 2 == 1)
// v_b = fp32_val.hi;
// else
// v_b = fp32_val.lo;
// }
// else
// {
// v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
// }
// v_acc += v_a * v_b;
// }
// c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
// };
// make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
// }
/**
 * @brief Naive GPU reference kernel for a token-sorted (MoE) GEMM. Each thread computes at
 *        most one output element.
 *
 * The flat thread index is split into `row = idx / N` (position in the sorted-token list)
 * and `col = idx % N`. Threads with `row >= M` or `row >= p_max_token_id_[0]` do nothing.
 *
 * Every entry of @p p_sorted_token_ids_ packs two fields: the low 24 bits are a token id,
 * and bits 24 and above are a top-k slot. When @p IsInputGemm is true, A is read at row
 * `token` and C is written at row `token * TopK + slot`. Otherwise A is read at row
 * `token * TopK + slot` and C is written at row `token`.
 *
 * The expert of a row is `p_sorted_expert_ids_[row / 128]` (the 128-row granularity is
 * hard-coded). Expert `e` uses the B matrix starting at element offset `e * N * K`.
 *
 * All element offsets are computed in 64-bit so that `expert_id * N * K` and
 * `row * stride` cannot overflow a 32-bit index for large problems.
 *
 * @tparam IsInputGemm  Selects which side (gather or scatter) uses the expanded
 *                      `token * TopK + slot` row.
 * @param Num_tokens    Unused.
 */
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          bool IsInputGemm = true>
__global__ void naive_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
                                  const ck_tile::index_t* p_sorted_expert_ids_,
                                  const ck_tile::index_t* p_max_token_id_,
                                  const ADataType* A,
                                  const BDataType* B,
                                  CDataType* C,
                                  ck_tile::index_t Num_tokens,
                                  ck_tile::index_t TopK,
                                  ck_tile::index_t M,
                                  ck_tile::index_t N,
                                  ck_tile::index_t K,
                                  ck_tile::index_t strideA,
                                  ck_tile::index_t strideB,
                                  ck_tile::index_t strideC)
{
    (void)Num_tokens;

    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = idx / N; // position in the sorted-token list
    const int col = idx % N; // output column

    // Reject out-of-range threads before touching the sorted-id tables, so padding
    // threads of the last block never read past the end of those tables.
    if(row >= M || col >= N || row >= p_max_token_id_[0])
    {
        return;
    }

    // unpack the sorted entry once: low 24 bits = token id, remaining bits = top-k slot
    const index_t sorted_id = p_sorted_token_ids_[row];
    const index_t token_id  = sorted_id & 0xffffff;
    const index_t topk_slot = sorted_id >> 24;
    const index_t expert_id = p_sorted_expert_ids_[row / 128];

    index_t gather_token_id  = token_id;
    index_t scatter_token_id = token_id;
    if constexpr(IsInputGemm)
    {
        scatter_token_id = token_id * TopK + topk_slot;
    }
    else
    {
        gather_token_id = token_id * TopK + topk_slot;
    }

    constexpr bool a_row_major = std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>;
    constexpr bool b_col_major = std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>;
    constexpr bool c_row_major = std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>;

    constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
    constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;

    // 64-bit copies so every offset below is evaluated in 64-bit arithmetic
    const long_index_t gather_row  = gather_token_id;
    const long_index_t scatter_row = scatter_token_id;
    const long_index_t col_l       = col;

    // TODO: add experts weights dispatch
    const long_index_t expert_offset = static_cast<long_index_t>(expert_id) * N * K;

    AccDataType acc = 0.0;
    for(int k = 0; k < K; ++k)
    {
        const long_index_t k_l = k;

        // Adjust indexing based on matrix layout
        const long_index_t a_index =
            a_row_major ? gather_row * strideA + k_l : k_l * strideA + gather_row;
        const long_index_t b_index =
            expert_offset + (b_col_major ? col_l * strideB + k_l : k_l * strideB + col_l);

        AccDataType v_a;
        AccDataType v_b;

        if constexpr(std::is_same_v<ADataType, pk_int4_t>)
        {
            // two int4 values per element: odd k selects .hi, even k selects .lo
            const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
            v_a                     = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
        }
        else
        {
            v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
        }

        if constexpr(std::is_same_v<BDataType, pk_int4_t>)
        {
            const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
            v_b                     = (k % 2 == 1) ? fp32_val.hi : fp32_val.lo;
        }
        else
        {
            v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
        }

        acc += v_a * v_b;
    }

    const long_index_t c_index =
        c_row_major ? scatter_row * strideC + col_l : col_l * strideC + scatter_row;
    C[c_index] = ck_tile::type_convert<CDataType>(acc);
}
/**
 * @brief Host-side launcher for naive_gemm_kernel (MoE reference GEMM).
 *
 * Launches one thread per element of the M x N output, 256 threads per block, on the
 * default stream. All pointer arguments are passed to the kernel unchanged. The launch is
 * asynchronous and this function performs no synchronization or error checking.
 *
 * An empty problem (M <= 0 or N <= 0) returns immediately without launching: it would
 * otherwise request a grid of zero blocks, which is not a valid launch configuration.
 */
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          bool IsInputGemm = true>
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
                            const index_t* p_sorted_expert_ids_,
                            const index_t* p_max_token_id_,
                            const ADataType* a_ptr,
                            const BDataType* b_ptr,
                            CDataType* c_ptr,
                            index_t Num_tokens,
                            index_t TopK,
                            index_t M,
                            index_t N,
                            index_t K,
                            index_t stride_a,
                            index_t stride_b,
                            index_t stride_c)
{
    // nothing to compute; also avoids launching with an empty grid
    if(M <= 0 || N <= 0)
    {
        return;
    }

    const int totalElements      = M * N;
    const int numThreadsPerBlock = 256; // Common choice for threads per block
    // ceil-divide so the last, partially filled block is still launched
    const int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    naive_gemm_kernel<ADataType,
                      BDataType,
                      AccDataType,
                      CDataType,
                      LayoutA,
                      LayoutB,
                      LayoutC,
                      IsInputGemm><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
                                                                      p_sorted_expert_ids_,
                                                                      p_max_token_id_,
                                                                      a_ptr,
                                                                      b_ptr,
                                                                      c_ptr,
                                                                      Num_tokens,
                                                                      TopK,
                                                                      M,
                                                                      N,
                                                                      K,
                                                                      stride_a,
                                                                      stride_b,
                                                                      stride_c);
}
} // namespace ck_tile

View File

@@ -6,7 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
namespace ck_tile {
template <typename ADataType_,
@@ -119,6 +119,124 @@ struct CShuffleEpilogue
return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
}
// Scatter variant of the C-shuffle epilogue.
//
// For every access of the space-filling curve over the (kMPerBlock x kNPerBlock)
// block tile, this overload:
//   1. derives per-row output offsets from the sorted token id table,
//   2. slices the matching part of the accumulator tile and casts it to ODataType,
//   3. shuffles it through LDS (per-warp store, block-wide reload with the DRAM
//      tile distribution),
//   4. writes it out through a scatter/gather window (memory_operation_enum::set)
//      or via update_tile (any other memory operation).
//
// Template parameters:
//   IsInputGemm        - when true, the row offset is token * TopK + slot; otherwise
//                        the plain token index is used.
//   out_memory_data_op - selects between the scatter store and update_tile paths.
// Parameters:
//   out_dram_window    - output window; it is moved along N between accesses.
//   o_acc_tile         - accumulator tile to be written out.
//   p_smem             - scratch memory, reinterpreted as ODataType* and viewed as LDS.
//   p_sorted_tokens_id - fused token ids; the low 24 bits are read as the token index
//                        and the bits above as a slot (presumably the top-k slot --
//                        confirm against the producer of this table).
//   token_pos          - base index into p_sorted_tokens_id for this block.
//
// NOTE(review): the row-repeat count (2), TopK (3) and the row stride (4096) are
// hard-coded below; they look like placeholders for values that should come from
// the problem description / kernel arguments.
template <typename ODramWindow,
typename OAccTile,
bool IsInputGemm = true,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
void* p_smem,
const index_t* p_sorted_tokens_id,
index_t token_pos)
{
// Split the linear warp id into (M, N) warp coordinates, kNWave warps per row.
const index_t iMWarp = get_warp_id() / kNWave;
const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
// Per-warp window: each warp writes its own kMPerXdl x kNPerXdl sub-tile.
auto in_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
{number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
// Block-wide window covering the sub-tiles of all warps, used for the reload.
auto out_lds_window =
make_tile_window(o_lds_block,
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
{0, 0});
// Walk the block tile in steps of one "all warps" sub-tile.
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
constexpr index_t num_access = SFC::get_num_of_access();
using TileEncodingPattern =
TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration,
kNPerIteration,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
// auto coord = dram_tile_distribution.calculate_index();
// const auto& view = out_dram_window.get_bottom_tensor_view();
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// This thread's coordinate within the DRAM tile distribution; component 0 is
// used below as the row index into the sorted token table.
auto c_coord = dram_tile_distribution.calculate_index();
// printf("c_coord[0]:%d \n", c_coord[0]);
CWarpTensor c_warp_in_tensor;
static_for<0, num_access, 1>{}([&](auto iAccess) {
constexpr auto idx_y_start = SFC::get_index(iAccess);
// auto idx_m = number<idx_y_start.at(number<0>{})>{} + 0;
// printf("idx_y_start:%d \n", idx_m);
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
// Output offsets for the rows this thread stores in the current access.
statically_indexed_array<index_t, 2> offsets;
// TODO(review): the repeat count 2 is hard-coded (see the CMrepeats note).
static_for<0, 2 /*CMrepeats*/, 1>{}([&](auto m0) {
auto token_id = token_pos + m0 + c_coord[0] + mIter * kMPerXdl * kMWave;
auto fused_token = p_sorted_tokens_id[token_id];
// Low 24 bits: token index; bits 24 and above: slot within the token.
index_t token_offset = fused_token & 0xffffff;
if constexpr(IsInputGemm)
{
// TODO(review): TopK is hard-coded to 3 here.
token_offset = token_offset * 3 /*TopK*/ + (fused_token >> 24);
}
// TODO(review): row stride hard-coded to 4096 instead of Problem::kN_.
offsets[m0] = token_offset * 4096; // Problem::kN_;
});
// printf("c_coord[number<0>{}]: %d \n", coord[number<0>{}]);
// printf("mIter: %d", mIter+0);
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
// printf("mIter, nIter(%d, %d) \n", mIter+0, nIter+0);
// Slice the (mIter, nIter) warp tile out of the accumulator.
c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
// The first sync protects LDS contents still being read from the previous
// access; the second makes this access's stores visible before the reload.
block_sync_lds();
store_tile(in_lds_window, c_warp_in_tensor_casted);
block_sync_lds();
const auto c_out_tensor =
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
if constexpr(out_memory_data_op == memory_operation_enum::set)
{
// Scatter store: rows are redirected through the per-row offsets.
auto tile_window = make_tile_scatter_gather(out_dram_window.get_bottom_tensor_view(),
out_dram_window.get_window_lengths(),
out_dram_window.get_window_origin(),
dram_tile_distribution,
offsets);
tile_window.store(c_out_tensor);
// store_tile(out_dram_window, c_out_tensor, offsets);
}
else
{
// NOTE(review): this path does not use `offsets`, so no token scatter is
// applied for non-`set` memory operations -- confirm this is intended.
update_tile(out_dram_window, c_out_tensor);
}
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
// Only the N component of the step moves the window; the M position is
// already folded into `offsets` through mIter.
move_tile_window(out_dram_window, {0, step.at(number<1>{})});
// printf("step.at(number<0>{}), step.at(number<1>{}):,%d, %d",
// step.at(number<0>{})+0, step.at(number<1>{})+0);
}
});
}
template <typename ODramWindow,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>

View File

@@ -265,6 +265,7 @@ struct FmhaFwdKernel
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
const int32_t* page_idx;
};
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
@@ -596,6 +597,7 @@ struct FmhaFwdKernel
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
const void* page_idx_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -654,7 +656,8 @@ struct FmhaFwdKernel
{}, // placeholder for dropout
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
reinterpret_cast<const int32_t*>(page_idx_ptr)};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -720,6 +723,7 @@ struct FmhaFwdKernel
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
const void* page_idx_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -758,6 +762,7 @@ struct FmhaFwdKernel
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
page_idx_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -799,6 +804,7 @@ struct FmhaFwdKernel
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
const void* page_idx_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -837,6 +843,7 @@ struct FmhaFwdKernel
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
page_idx_ptr,
hdim_q,
hdim_v,
num_head_q,
@@ -958,7 +965,7 @@ struct FmhaFwdKernel
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
long_index_t batch_offset_q = 0;
long_index_t batch_offset_k = 0;
// long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
@@ -972,7 +979,7 @@ struct FmhaFwdKernel
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
// batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
@@ -981,6 +988,8 @@ struct FmhaFwdKernel
{
batch_offset_v = key_start;
}
kargs.page_idx += key_start;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias + key_start;
@@ -1016,35 +1025,38 @@ struct FmhaFwdKernel
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
}
}
else
{
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
}
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
if constexpr(kHasDropout)
{
batch_offset_randval =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
}
// else
// {
// batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
// batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
// batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
// if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
// {
// batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
// }
// if constexpr(kStoreLSE)
// {
// batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
// }
// if constexpr(kHasDropout)
// {
// batch_offset_randval =
// static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
// }
// batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
// }
// for simplicity, batch stride we just modify the pointer
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
// const KDataType* k_ptr =
// reinterpret_cast<const KDataType*>(kargs.k_ptr) +
// static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
// batch_offset_k;
const KDataType* k_ptr =
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
batch_offset_k;
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
const VDataType* v_ptr =
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
@@ -1329,6 +1341,9 @@ struct FmhaFwdKernel
position_encoding,
kargs.scale_s,
smem_ptr,
kargs.page_idx,
kargs.stride_k,
kargs.stride_v,
dropout);
}
else
@@ -1343,6 +1358,9 @@ struct FmhaFwdKernel
position_encoding,
kargs.scale_s,
smem_ptr,
kargs.page_idx,
kargs.stride_k,
kargs.stride_v,
dropout);
}
}();

View File

@@ -8,6 +8,9 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
// #include "ck_tile/core/tensor/tile_scatter_gather_debug.hpp"
namespace ck_tile {
@@ -45,6 +48,10 @@ struct BlockFmhaPipelineQRKSVS
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
@@ -148,6 +155,9 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
DropoutType& dropout) const
{
static_assert(
@@ -241,11 +251,10 @@ struct BlockFmhaPipelineQRKSVS
return o_acc;
}
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
{seqlen_k_start, 0}); //todo fixme felix
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
@@ -257,11 +266,28 @@ struct BlockFmhaPipelineQRKSVS
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
auto v_coord = v_dist.calculate_index();
const auto VPageIndexDim = I1;
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
statically_indexed_array<index_t, V_KRepeat> v_offsets;
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
// printf("1tid %d %d %d %d %d\n", threadIdx.x, v_coord[VPageIndexDim], k0.value, page_idx[v_coord[VPageIndexDim] + k0.value], stride_v);
});
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
v_dist,
v_offsets,
VPageIndexDim);
// auto v_dram_window =
// make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
// v_dram_block_window_tmp.get_window_lengths(),
// {0, seqlen_k_start}, // TODO: hdim split?
// v_dist);
auto q_tile = tile_elementwise_in(q_element_func, q);
@@ -275,12 +301,20 @@ struct BlockFmhaPipelineQRKSVS
do
{
// STAGE 1, QK gemm
auto k_dram_window = make_tile_window(
auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
auto k_coord = k_dist.calculate_index();
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
statically_indexed_array<index_t, NRepeat> k_offsets;
static_for<0, NRepeat, 1>{}([&](auto n0) {
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
});
auto k_dram_window = make_tile_scatter_gather(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
k_dist,
k_offsets); // K DRAM tile window for
auto k_block_tile = load_tile(k_dram_window);
{
@@ -321,7 +355,13 @@ struct BlockFmhaPipelineQRKSVS
});
}
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
const auto v_prefetch = v_dram_window.load(); // prefetch load v tile
// const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
// printf("2tid %d %d %d %d\n", threadIdx.x, v_coord[VPageIndexDim], kK1 + v_coord[VPageIndexDim] + k0.value, page_idx[kK1 + v_coord[VPageIndexDim] + k0.value]);
});
v_dram_window.update_page_idx(v_offsets);
{ // tail
block_sync_lds();
gemm_0(s_acc,
@@ -523,6 +563,12 @@ struct BlockFmhaPipelineQRKSVS
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto v = load_tile(v_dram_window); // load next v
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
// printf("3tid %d %d %d %d\n", threadIdx.x, v_coord[VPageIndexDim], kK1 * 2 + i_k1.value * kK1 + v_coord[VPageIndexDim] + k0.value, page_idx[kK1 + i_k1.value * kK1 + v_coord[VPageIndexDim] + k0.value]);
});
v_dram_window.update_page_idx(v_offsets);
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
@@ -556,6 +602,7 @@ struct BlockFmhaPipelineQRKSVS
v_lds_window);
block_sync_lds();
}
page_idx += kN0;
} while(++i_total_loops < num_total_loop);
// store lse
@@ -626,6 +673,9 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
const index_t* page_idx,
const index_t stride_k,
const index_t stride_v,
DropoutType& dropout) const
{
return operator()(q_dram_block_window_tmp,
@@ -646,6 +696,9 @@ struct BlockFmhaPipelineQRKSVS
position_encoding,
scale_s,
smem_ptr,
page_idx,
stride_k,
stride_v,
dropout);
}
};

View File

@@ -623,7 +623,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t kKPack = GetSmemKPackV<Problem>(); //
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
@@ -783,19 +783,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t N1 = GetAlignmentV<Problem>(); // 8
constexpr index_t N0 = kNPerBlock / N1; // P // 16
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t K3 = total_pixels / N1; //2
constexpr index_t kKPack = GetSmemKPackV<Problem>(); // 8
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K2 = kKPack / K3; // //4 TODO: this dimention could be outside single wave
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t K1 = get_warp_size() / (K2 * N0); // 2
constexpr index_t K0 = kBlockSize / get_warp_size(); // 2
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,

View File

@@ -19,7 +19,6 @@ cmake
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS=$GPU_TARGETS \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \