mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
merge pa
This commit is contained in:
@@ -533,11 +533,6 @@ include_directories(BEFORE
|
||||
${HIP_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
|
||||
if(BUILD_DEV)
|
||||
add_compile_options(-Werror)
|
||||
add_compile_options(-Weverything)
|
||||
endif()
|
||||
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
|
||||
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
|
||||
|
||||
@@ -66,7 +66,6 @@ else()
|
||||
-Wunreachable-code
|
||||
-Wunused
|
||||
-Wno-reserved-identifier
|
||||
-Werror
|
||||
-Wno-option-ignored
|
||||
-Wsign-compare
|
||||
-Wno-extra-semi-stmt
|
||||
|
||||
@@ -333,14 +333,8 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
size_t total_size =
|
||||
(M * K * sizeof(A0DataType) + N * K * sizeof(B0DataType) + M * sizeof(D0DataType) +
|
||||
N * sizeof(D1DataType) + M * N * sizeof(EDataType));
|
||||
int rotate_buf_num =
|
||||
ck::math::min(size_t(Repeat), ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size));
|
||||
|
||||
float ave_time = invoker.Run(
|
||||
argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, true, rotate_buf_num});
|
||||
argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, false, 1});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
|
||||
@@ -21,7 +21,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --receipt 200
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
@@ -45,7 +45,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 200
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp16" : "FmhaFwdFp16",
|
||||
# "fp16" : "FmhaFwdFp16",
|
||||
"bf16" : "FmhaFwdBf16",
|
||||
"fp8" : "FmhaFwdFp8",
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16"
|
||||
# "fp8" : "FmhaFwdFp8",
|
||||
# "fp8fp16": "FmhaFwdFp8Fp16",
|
||||
# "fp8bf16": "FmhaFwdFp8Bf16"
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {
|
||||
|
||||
@@ -118,7 +118,7 @@ FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
@@ -288,7 +288,7 @@ class FmhaFwdApiPool:
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
@@ -413,18 +413,17 @@ class FmhaFwdKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
# '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -441,7 +440,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
if False:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
@@ -449,20 +448,24 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
# if bias == "bias":
|
||||
# # TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# else:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# if receipt == 1 and bias != "bias":
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
@@ -490,10 +493,6 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't' or (pipeline.F_mask not in ['no', 's_no']):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
@@ -534,6 +533,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_bias == 'no'
|
||||
cond &= pipeline.F_lse == 'f'
|
||||
cond &= pipeline.F_dropout == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
|
||||
@@ -480,6 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
|
||||
// std::vector<int32_t> page_idx_host(seqstart_k_host.back(), 0);
|
||||
ck_tile::HostTensor<int32_t> page_idx_host({seqstart_k_host.back()});
|
||||
// std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
|
||||
iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end(), 0);
|
||||
// for (int i = 0; i < page_idx_host.get_element_space_size(); i++) {
|
||||
// page_idx_host(i) = (i + 19) % page_idx_host.size();
|
||||
// }
|
||||
page_idx_host.savetxt("page_idx_host.txt", "int");
|
||||
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
@@ -585,7 +593,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
};
|
||||
|
||||
bool is_v_rowmajor = vlayout == std::string("r");
|
||||
|
||||
assert(is_v_rowmajor);
|
||||
assert(!i_perm);
|
||||
|
||||
// host memory for storing all the tensor elements
|
||||
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
const ck_tile::index_t shape_seqlen_q =
|
||||
@@ -601,6 +611,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
ck_tile::HostTensor<KDataType> k_host_sgl({seqstart_k_host.back(), nhead_k, hdim_q});
|
||||
|
||||
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
|
||||
ck_tile::HostTensor<KDataType> knew_host(
|
||||
0 < seqlen_knew
|
||||
@@ -613,6 +625,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size))
|
||||
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
|
||||
ck_tile::HostTensor<KDataType> v_host_sgl({seqstart_k_host.back(), nhead_k, hdim_v});
|
||||
ck_tile::HostTensor<VDataType> vnew_host(
|
||||
0 < seqlen_knew
|
||||
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
|
||||
@@ -742,13 +755,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
}
|
||||
k_host.ForEach([&](auto& self, auto i) {
|
||||
k_host_sgl(page_idx_host(i[1]), i[2], i[3]) = self(i);
|
||||
});
|
||||
v_host.ForEach([&](auto& self, auto i) {
|
||||
v_host_sgl(page_idx_host(i[1]), i[2], i[3]) = self(i);
|
||||
});
|
||||
// k_host.savetxt("k_host.txt");
|
||||
// k_host_sgl.savetxt("k_host_sgl.txt");
|
||||
// v_host_sgl.savetxt("v_host_sgl.txt");
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
|
||||
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host_sgl.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host_sgl.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
|
||||
@@ -757,6 +778,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem page_idx(page_idx_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
|
||||
0 <= seqlen_kpads[0]
|
||||
? seqlen_ks.size() * sizeof(int32_t)
|
||||
@@ -773,14 +795,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
k_buf.ToDevice(k_host_sgl.data());
|
||||
knew_buf.ToDevice(knew_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
v_buf.ToDevice(v_host_sgl.data());
|
||||
vnew_buf.ToDevice(vnew_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
|
||||
: seqstart_k_with_padding_host.data());
|
||||
page_idx.ToDevice(page_idx_host.data());
|
||||
|
||||
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || 0 <= seqlen_kpads[0]
|
||||
? seqlen_ks.data()
|
||||
: nullptr);
|
||||
@@ -996,6 +1020,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(mode == mode_enum::group ? seqstart_q.GetDeviceBuffer() : nullptr);
|
||||
args.seqstart_k_ptr =
|
||||
(mode == mode_enum::group ? seqstart_k.GetDeviceBuffer() : nullptr);
|
||||
args.page_idx_ptr =
|
||||
(mode == mode_enum::group ? page_idx.GetDeviceBuffer() : nullptr);
|
||||
args.seqlen_k_ptr = ((mode == mode_enum::batch && use_kvcache) || 0 <= k_paddings_[0]
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
@@ -1171,7 +1197,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
|
||||
o_buf.FromDevice(o_host.data()); // TODO: ugly
|
||||
|
||||
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool pass_ = ck_tile::check_err(
|
||||
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
|
||||
@@ -1513,6 +1538,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
// clang-format on
|
||||
|
||||
// o_host_result.savetxt("o_host_result.txt");
|
||||
// o_host_ref.savetxt("o_host_ref.txt");
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(
|
||||
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
|
||||
@@ -129,7 +129,8 @@ struct fmha_fwd_args
|
||||
const void* seqstart_k_ptr;
|
||||
const void*
|
||||
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
|
||||
|
||||
const void* page_idx_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t batch;
|
||||
@@ -326,6 +327,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.page_idx_ptr,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
|
||||
316
example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc
Normal file
316
example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc
Normal file
@@ -0,0 +1,316 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
// using BF16 = ck::bhalf_t;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using F32 = float;
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float invoke_gemm(int n_warmup, int n_repeat, const moe_gemm_kargs& args)
|
||||
{
|
||||
float ave_time = moe_gemm<ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Moe Gemm"};
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
|
||||
flop += std::size_t(2) * args.M * args.N * args.K;
|
||||
|
||||
num_btype += sizeof(ADataType) * args.M * args.K + sizeof(BDataType) * args.K * args.N +
|
||||
sizeof(CDataType) * args.M * args.N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
int run_moe_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
// auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
// return !(args.empty() || ...) && group_count == (args.size() == ...);
|
||||
// };
|
||||
|
||||
// ck_tile::index_t N = 4096;
|
||||
// ck_tile::index_t K = 4096;
|
||||
// ck_tile::index_t experts = 8;
|
||||
// ck_tile::index_t sorted_tile_num = 8;
|
||||
// ck_tile::index_t valid_tile_num = 8;
|
||||
// ck_tile::index_t tokens = 128;
|
||||
// ck_tile::index_t topk = 2;
|
||||
|
||||
const ck_tile::index_t N = arg_parser.get_int("N");
|
||||
const ck_tile::index_t K = arg_parser.get_int("K");
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_A");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_B");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_C");
|
||||
const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
|
||||
const ck_tile::index_t topk = arg_parser.get_int("TopK");
|
||||
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
|
||||
const ck_tile::index_t experts = arg_parser.get_int("experts");
|
||||
|
||||
// TODO: replace the magic declaration
|
||||
const ck_tile::index_t MPerBlock = 128;
|
||||
ck_tile::index_t sorted_tile_num = 8;
|
||||
ck_tile::index_t valid_tile_num = sorted_tile_num;
|
||||
|
||||
const ck_tile::index_t M = sorted_tile_num * MPerBlock;
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf;
|
||||
std::unique_ptr<ck_tile::DeviceMem> b_k_n_dev_buf;
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, N, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
|
||||
// TODO: add the experts' weights in b
|
||||
auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
|
||||
is_row_major(b_layout)
|
||||
? ck_tile::host_tensor_descriptor(experts * K, N, stride_B, is_row_major(b_layout))
|
||||
: ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
|
||||
auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
std::cout << "gemm"
|
||||
<< " a_m_k: " << a_m_k_tensor.mDesc << " b_k_n: " << b_k_n_tensor.mDesc
|
||||
<< " c_m_n: " << c_m_n_tensor.mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensor);
|
||||
|
||||
a_m_k_dev_buf =
|
||||
std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes());
|
||||
b_k_n_dev_buf =
|
||||
std::make_unique<ck_tile::DeviceMem>(b_k_n_tensor.get_element_space_size_in_bytes());
|
||||
c_m_n_dev_buf =
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
|
||||
b_k_n_dev_buf->ToDevice(b_k_n_tensor.data());
|
||||
c_m_n_dev_buf->SetZero();
|
||||
c_m_n_tensor.SetZero();
|
||||
|
||||
const void* p_a = a_m_k_dev_buf->GetDeviceBuffer();
|
||||
const void* p_b = b_k_n_dev_buf->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf->GetDeviceBuffer();
|
||||
|
||||
// TODO: malloc and init sorted tokens and max tokens buffer
|
||||
|
||||
ck_tile::HostTensor<ck_tile::index_t> expert_ids(
|
||||
ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
|
||||
ck_tile::HostTensorDescriptor({sorted_tile_num * MPerBlock}, {1}));
|
||||
ck_tile::HostTensor<ck_tile::index_t> max_token_id(
|
||||
ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> sorted_token_ids_dev = std::make_unique<ck_tile::DeviceMem>(
|
||||
sizeof(ck_tile::index_t) * sorted_token_ids.get_element_space_size_in_bytes());
|
||||
std::unique_ptr<ck_tile::DeviceMem> expert_ids_dev = std::make_unique<ck_tile::DeviceMem>(
|
||||
sizeof(ck_tile::index_t) * expert_ids.get_element_space_size_in_bytes());
|
||||
std::unique_ptr<ck_tile::DeviceMem> max_token_id_dev = std::make_unique<ck_tile::DeviceMem>(
|
||||
sizeof(ck_tile::index_t) * max_token_id.get_element_space_size_in_bytes());
|
||||
|
||||
max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
|
||||
int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
// sorted_token_ids.mData[0] = 0;
|
||||
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile && tokenid < num_tokens * topk)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = num_tokens - 1;
|
||||
}
|
||||
}
|
||||
|
||||
sorted_token_ids_dev->ToDevice(sorted_token_ids.data());
|
||||
expert_ids_dev->ToDevice(expert_ids.data());
|
||||
max_token_id_dev->ToDevice(max_token_id.data());
|
||||
|
||||
const ck_tile::index_t* p_sorted_token_ids_dev =
|
||||
static_cast<ck_tile::index_t*>(sorted_token_ids_dev->GetDeviceBuffer());
|
||||
const ck_tile::index_t* p_expert_ids_dev =
|
||||
static_cast<ck_tile::index_t*>(expert_ids_dev->GetDeviceBuffer());
|
||||
const ck_tile::index_t* p_max_token_id_dev =
|
||||
static_cast<ck_tile::index_t*>(max_token_id_dev->GetDeviceBuffer());
|
||||
|
||||
moe_gemm_kargs gemm_desc{p_sorted_token_ids_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
num_tokens,
|
||||
topk,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C};
|
||||
|
||||
invoke_gemm<ALayout, BLayout, CLayout>(3, repeat, gemm_desc);
|
||||
|
||||
c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("validate"))
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
|
||||
|
||||
c_m_n_ref_buf->SetZero();
|
||||
|
||||
ck_tile::reference_moe_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
p_sorted_token_ids_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
|
||||
num_tokens,
|
||||
topk,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1 /*kbatch*/, max_accumulated_value);
|
||||
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
|
||||
|
||||
// for(int im = 0; im < M; im++)
|
||||
// {
|
||||
// for(int in = 0; in < N; in++)
|
||||
// {
|
||||
// // if (static_cast<float>(static_cast<CDataType*>(p_c)[im * N + in]) != 0)
|
||||
// printf("c[%d][%d]: %f ",
|
||||
// im,
|
||||
// in,
|
||||
// static_cast<float>(static_cast<CDataType*>(p_c)[im * N + in]));
|
||||
// printf("ref[%d][%d]: %f \n",
|
||||
// im,
|
||||
// in,
|
||||
// static_cast<float>(
|
||||
// static_cast<CDataType*>(c_m_n_host_ref.data())[im * N + in]));
|
||||
// }
|
||||
// }
|
||||
|
||||
pass = ck_tile::check_err(c_m_n_tensor,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int run_moe_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
// else if(a_layout == "R" && b_layout == "R")
|
||||
// {
|
||||
// return run_grouped_gemm_example_with_layouts(argc, argv, Row{}, Row{}, Row{});
|
||||
// }
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
@@ -18,32 +18,10 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
template <typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
@@ -51,35 +29,11 @@ CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
index_t NumCoord,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
const tile_window_with_static_distribution<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
NumCoord>& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor_,
|
||||
typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
typename LinearBottomDims_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
|
||||
const tile_window_linear<BottomTensorView_,
|
||||
WindowLengths_,
|
||||
TileDistribution_,
|
||||
LinearBottomDims_>& tile_window,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -104,6 +104,31 @@ CK_TILE_DEVICE void store_tile(
|
||||
tile_window.store(dstr_tensor, number<-1>{});
|
||||
}
|
||||
|
||||
// template <typename T,
|
||||
// typename BottomTensorView_,
|
||||
// typename WindowLengths_,
|
||||
// typename TileDistribution_,
|
||||
// typename DataType_>
|
||||
// CK_TILE_DEVICE void
|
||||
// store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
|
||||
// const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
|
||||
// const T& offsets)
|
||||
// {
|
||||
// using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
|
||||
// using TileDstr = remove_cvref_t<TileDistribution_>;
|
||||
|
||||
// static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
|
||||
|
||||
// constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
|
||||
// tile_window_tmp.get_window_lengths(),
|
||||
// tile_window_tmp.get_window_origin(),
|
||||
// tile_dstr);
|
||||
|
||||
// tile_window.store(dstr_tensor, offsets);
|
||||
// }
|
||||
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename TileDistribution_,
|
||||
|
||||
654
include/ck_tile/core/tensor/tile_scatter_gather.hpp
Normal file
654
include/ck_tile/core/tensor/tile_scatter_gather.hpp
Normal file
@@ -0,0 +1,654 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/utility.hpp"
|
||||
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/array.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
|
||||
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief This class provides tile (windowed) view and access to the device memory.
|
||||
*
|
||||
* @note This tile window does not support single issue you need to use tile_window_linear
|
||||
* structure for this purpose
|
||||
*
|
||||
* @tparam BottomTensorView_ Class describing & holding device tensor memory.
|
||||
* @tparam WindowLengths_ Spatial sizes of windowed view on tensor.
|
||||
* @tparam StaticTileDistribution_ Thread distribution (mapping) into Tile dimensions
|
||||
* @tparam NumCoord TBD
|
||||
*/
|
||||
template <typename BottomTensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1,
|
||||
index_t YsGatherDim = 0>
|
||||
struct tile_scatter_gather
|
||||
{
|
||||
using BottomTensorView = remove_reference_t<BottomTensorView_>;
|
||||
using WindowLengths = remove_cvref_t<WindowLengths_>;
|
||||
using TileDstr = remove_cvref_t<StaticTileDistribution_>;
|
||||
using PageIdxArray = remove_cvref_t<StaticPageIndexArray_>;
|
||||
using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
|
||||
using BottomTensorDesc = typename BottomTensorView::TensorDesc;
|
||||
|
||||
using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
|
||||
|
||||
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
|
||||
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
|
||||
|
||||
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
|
||||
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static_assert(NumCoord == 1);
|
||||
|
||||
// TODO: check WindowLengths and StaticTileDistribution are consistent
|
||||
|
||||
static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
|
||||
"wrong! lengths should be static");
|
||||
static_assert(TileDstr::is_static(), "wrong!");
|
||||
|
||||
static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
|
||||
"wrong! inconsistent # of diemsnions");
|
||||
|
||||
using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
|
||||
using BottomTensorIndex = array<index_t, NDimBottomTensor>;
|
||||
|
||||
using WindowAdaptorCoord =
|
||||
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
|
||||
|
||||
using BottomTensorCoord =
|
||||
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
|
||||
|
||||
struct load_store_traits
|
||||
{
|
||||
private:
|
||||
static constexpr auto get_vector_dim_y_scalar_per_vector()
|
||||
{
|
||||
const auto [ys_vector_lengths, ys_vector_strides] =
|
||||
tile_scatter_gather::
|
||||
get_window_adaptor_ys_safe_vector_length_strides();
|
||||
|
||||
index_t VectorDimY_ = 0;
|
||||
index_t ScalarPerVector_ = 1;
|
||||
|
||||
for(index_t i = 0; i < NDimY; ++i)
|
||||
{
|
||||
if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
|
||||
{
|
||||
ScalarPerVector_ = ys_vector_lengths[i];
|
||||
VectorDimY_ = i;
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(VectorDimY_, ScalarPerVector_);
|
||||
}
|
||||
|
||||
public:
|
||||
static constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
|
||||
static constexpr index_t ScalarPerVector =
|
||||
get_vector_dim_y_scalar_per_vector().template at<1>();
|
||||
|
||||
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
|
||||
// using vector_t = typename vector_type_t::type;
|
||||
using vector_t = thread_buffer<DataType, ScalarPerVector / PackedSize>;
|
||||
|
||||
private:
|
||||
static constexpr auto scalars_per_access_ = [] {
|
||||
constexpr auto scalars_per_access_arr = generate_array(
|
||||
[&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
|
||||
|
||||
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
|
||||
constexpr auto NDimY_ = NDimY;
|
||||
|
||||
return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
|
||||
}();
|
||||
|
||||
static constexpr auto get_space_filling_curve()
|
||||
{
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
constexpr auto thread_tensor_lengths_ys =
|
||||
to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// FIXME: need logic to judge dim access order
|
||||
using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
|
||||
|
||||
return space_filling_curve<decltype(thread_tensor_lengths_ys),
|
||||
DimAccessOrder,
|
||||
decltype(scalars_per_access_)>{};
|
||||
}
|
||||
|
||||
public:
|
||||
using SFC_Ys = decltype(get_space_filling_curve());
|
||||
|
||||
static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
|
||||
|
||||
static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
|
||||
static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
|
||||
};
|
||||
|
||||
static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
|
||||
|
||||
CK_TILE_DEVICE constexpr tile_scatter_gather(
|
||||
const BottomTensorView& bottom_tensor_view,
|
||||
const WindowLengths& window_lengths,
|
||||
const BottomTensorIndex& window_origin,
|
||||
const TileDstr& tile_distribution,
|
||||
const PageIdxArray& page_idx)
|
||||
: bottom_tensor_view_{bottom_tensor_view},
|
||||
window_lengths_{window_lengths},
|
||||
window_origin_{window_origin},
|
||||
tile_dstr_{tile_distribution},
|
||||
page_idx_{page_idx},
|
||||
pre_computed_coords_{}
|
||||
{
|
||||
#if 0 // debug
|
||||
// TODO: this use more register for FA, but less register for GEMM
|
||||
// need investigation
|
||||
// only support warp-tile and block-tile
|
||||
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
|
||||
|
||||
if constexpr(NDimP == 1)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
|
||||
}
|
||||
else if constexpr(NDimP == 2)
|
||||
{
|
||||
window_adaptor_thread_coord_tmp =
|
||||
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
|
||||
}
|
||||
#else
|
||||
// TODO: this use less register for FA, but more register for GEMM
|
||||
// need investigation
|
||||
const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
tile_distribution.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(detail::get_partition_index(tile_distribution),
|
||||
array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
|
||||
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
// tuple<index_t, index_t>(0, window_adaptor_thread_coord_tmp.get_bottom_index()[1]);
|
||||
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
|
||||
// future load/store() calls (might allocate more registers)
|
||||
using Traits = load_store_traits;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
|
||||
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
|
||||
|
||||
constexpr auto idx_diff_ys =
|
||||
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
pre_computed_coords_(iCoord) =
|
||||
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
|
||||
|
||||
CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
|
||||
{
|
||||
return TileDstr::is_static();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
|
||||
|
||||
CK_TILE_DEVICE constexpr void
|
||||
set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
|
||||
{
|
||||
bottom_tensor_view_.buf_.p_data_ = data;
|
||||
}
|
||||
|
||||
// move thread's window adaptor coordinate and bottom tensor coordinate
|
||||
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
|
||||
template <typename ATopIndex>
|
||||
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
WindowAdaptorCoord& window_adaptor_thread_coord,
|
||||
BottomTensorCoord& bottom_tensor_thread_coord,
|
||||
const ATopIndex& idx_diff_adaptor_top) const
|
||||
{
|
||||
array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
|
||||
|
||||
move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
window_adaptor_thread_coord,
|
||||
idx_diff_adaptor_top,
|
||||
idx_diff_adaptor_bottom);
|
||||
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_adaptor_bottom);
|
||||
}
|
||||
|
||||
// return vector dimension among [y0, y1, ...]
|
||||
CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
|
||||
{
|
||||
// bottom tensor top dimension vector lengths and strides
|
||||
const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
|
||||
BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
|
||||
|
||||
// window vector lengths/strides
|
||||
const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
|
||||
const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
|
||||
|
||||
// window adaptor [p0, p1, ..., y0, y1, ...]
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
|
||||
-1};
|
||||
array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
|
||||
-1};
|
||||
|
||||
constexpr auto window_adaptor_bottom_dims =
|
||||
WindowAdaptor::get_bottom_dimension_hidden_ids();
|
||||
|
||||
set_container_subset(window_adaptor_vector_lengths,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_lengths);
|
||||
set_container_subset(window_adaptor_vector_strides,
|
||||
window_adaptor_bottom_dims,
|
||||
window_adaptor_bottom_dim_vector_strides);
|
||||
|
||||
const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
|
||||
WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
|
||||
window_adaptor_vector_lengths, window_adaptor_vector_strides);
|
||||
|
||||
// [y0, y1, ...]
|
||||
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
|
||||
NDimWindowAdaptorTop,
|
||||
1>::type{};
|
||||
|
||||
return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
|
||||
get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto get_num_of_access() const { return load_store_traits::NumAccess; }
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
|
||||
load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
return dst_tensor;
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_m = idx_ys_start[number<YsGatherDim>{}];
|
||||
const auto page_offset = page_idx_[idx_m];
|
||||
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, page_offset, bool_constant<oob_conditional_check>{});
|
||||
#if 1
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
static_assert(d % Traits::ScalarPerVector == 0);
|
||||
|
||||
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
|
||||
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
|
||||
#endif
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
forward_step_scatter);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = load_store_traits;
|
||||
|
||||
// using vector_type_t = typename Traits::vector_type_t;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
constexpr auto tile_dstr = TileDstr{};
|
||||
// printf("off %d\n", page_idx_[I0]);
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
// window_origin_ +
|
||||
// tuple<index_t, index_t>(0, window_adaptor_thread_coord.get_bottom_index()[1]);
|
||||
|
||||
// auto bottom_tensor_thread_coord = make_tensor_coordinate(
|
||||
// bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// data index [y0, y1, ...]
|
||||
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
constexpr auto idx_m = idx_ys_start[number<0>{}];
|
||||
const auto page_offset = page_idx_[idx_m];
|
||||
|
||||
// printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
|
||||
// idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
|
||||
|
||||
// read from distributed tensor
|
||||
// vector_type_t vec;
|
||||
vector_t vec_value;
|
||||
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
[&](auto jj) {
|
||||
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
: idx_ys_start[jj];
|
||||
},
|
||||
number<NDimY>{});
|
||||
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
Traits::PackedSize;
|
||||
// printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
|
||||
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
dstr_tensor.get_thread_buffer().template at<d>();
|
||||
});
|
||||
|
||||
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
|
||||
|
||||
// write into bottom tensor
|
||||
get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord,
|
||||
page_offset,
|
||||
vec_value,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
// printf("coord_offset:%d, scatter_offset:%d \n",
|
||||
// bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
constexpr auto forward_step_scatter = generate_tuple(
|
||||
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; }, number<NDimY>{});
|
||||
|
||||
constexpr auto idx_diff_ps_ys = container_concat(
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
forward_step_scatter);
|
||||
|
||||
move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// move thread's botom tensor coordiante
|
||||
// [x0', x1', ... ] ==> [offset]
|
||||
// also move window-origin
|
||||
CK_TILE_DEVICE void move(const BottomTensorIndex& step)
|
||||
{
|
||||
window_origin_ += step;
|
||||
BottomTensorIndex step_new = step;
|
||||
step_new(HsGatherDim) = 0;
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
|
||||
pre_computed_coords_(iCoord)(I1),
|
||||
step_new);
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx)
|
||||
{
|
||||
page_idx_ = new_idx;
|
||||
|
||||
// static_for<0, 2, 1>{}([&](auto k0) {
|
||||
// printf("update tid %d %d \n", threadIdx.x, page_idx_[k0]);
|
||||
// });
|
||||
}
|
||||
// CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
|
||||
// {
|
||||
// window_origin_ = new_window_origin;
|
||||
|
||||
// #if 0 // debug
|
||||
// // TODO: this use more register for FA, but less register for GEMM
|
||||
// // need investigation
|
||||
// // only support warp-tile and block-tile
|
||||
// static_assert(NDimP == 1 or NDimP == 2, "wrong!");
|
||||
|
||||
// WindowAdaptorCoord window_adaptor_thread_coord_tmp;
|
||||
|
||||
// if constexpr(NDimP == 1)
|
||||
// {
|
||||
// window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
// tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
|
||||
// }
|
||||
// else if constexpr(NDimP == 2)
|
||||
// {
|
||||
// window_adaptor_thread_coord_tmp =
|
||||
// make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
// AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
|
||||
// }
|
||||
// #else
|
||||
// // TODO: this use less register for FA, but more register for GEMM
|
||||
// // need investigation
|
||||
// const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
|
||||
// tile_dstr_.get_ps_ys_to_xs_adaptor(),
|
||||
// container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
|
||||
// #endif
|
||||
|
||||
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
// window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
|
||||
// const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
// bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
|
||||
// // future load/store() calls (might allocate more registers)
|
||||
// using Traits = load_store_traits;
|
||||
// using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
// static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
// auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
|
||||
// auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
|
||||
|
||||
// constexpr auto idx_diff_ys =
|
||||
// SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
|
||||
|
||||
// constexpr auto idx_diff_ps_ys = container_concat(
|
||||
// generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
|
||||
|
||||
// move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
// window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
|
||||
// pre_computed_coords_(iCoord) =
|
||||
// make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
|
||||
// });
|
||||
// }
|
||||
|
||||
CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
|
||||
|
||||
// this is the bottom tensor view
|
||||
// [x0', x1', ...] ==> [offset]
|
||||
BottomTensorView bottom_tensor_view_;
|
||||
|
||||
//
|
||||
WindowLengths window_lengths_;
|
||||
|
||||
// origin ([x0', x1', ...]) of window on bottom tensor
|
||||
BottomTensorIndex window_origin_;
|
||||
|
||||
// Tile tensor distribution, which contains:
|
||||
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
|
||||
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
|
||||
TileDstr tile_dstr_;
|
||||
|
||||
PageIdxArray page_idx_;
|
||||
|
||||
// this contains:
|
||||
// per-thread coordinate for window adaptor
|
||||
// per-thread coordinate for bottom tensor
|
||||
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
|
||||
};
|
||||
|
||||
// TODO: use strategy
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename StaticPageIndexArray_,
|
||||
index_t HsGatherDim = 0,
|
||||
index_t NumCoord = 1>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_scatter_gather(const TensorView_& tensor_view,
|
||||
const WindowLengths_& window_lengths,
|
||||
const multi_index<TensorView_::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution_& tile_distribution,
|
||||
const StaticPageIndexArray_& page_idx,
|
||||
number<HsGatherDim> = {},
|
||||
number<NumCoord> = {})
|
||||
{
|
||||
return tile_scatter_gather<remove_cvref_t<TensorView_>,
|
||||
remove_cvref_t<WindowLengths_>,
|
||||
remove_cvref_t<StaticTileDistribution_>,
|
||||
remove_cvref_t<StaticPageIndexArray_>,
|
||||
HsGatherDim,
|
||||
NumCoord>{
|
||||
tensor_view, window_lengths, origin, tile_distribution, page_idx};
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution, typename StaticPageIndexArray, index_t HsGatherDim>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_scatter_gather(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const multi_index<TensorView::get_num_of_dimension()>& origin,
|
||||
const StaticTileDistribution& tile_distribution,
|
||||
const StaticPageIndexArray& page_idx,
|
||||
number<HsGatherDim> = {})
|
||||
{
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
origin,
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
template <typename TensorView, typename WindowLengths, typename StaticTileDistribution, typename StaticPageIndexArray, index_t HsGatherDim>
|
||||
CK_TILE_DEVICE constexpr auto
|
||||
make_tile_scatter_gather(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
const StaticTileDistribution& tile_distribution, const StaticPageIndexArray& page_idx,
|
||||
number<HsGatherDim> = {})
|
||||
{
|
||||
return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
tile_window.get_window_lengths(),
|
||||
tile_window.get_window_origin(),
|
||||
tile_distribution,
|
||||
page_idx,
|
||||
number<HsGatherDim>{});
|
||||
}
|
||||
|
||||
// template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
|
||||
// CK_TILE_DEVICE constexpr auto
|
||||
// make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
|
||||
// const StaticTileDistribution& tile_distribution)
|
||||
// {
|
||||
// auto w = make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
|
||||
// tile_window.get_window_lengths(),
|
||||
// tile_window.get_window_origin(),
|
||||
// tile_distribution);
|
||||
// w.init_raw();
|
||||
// return w;
|
||||
// }
|
||||
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -609,6 +609,93 @@ struct tile_window_with_static_distribution
|
||||
});
|
||||
}
|
||||
|
||||
// template <typename statically_indexed_array,
|
||||
// index_t i_access_unsupport_ = -1,
|
||||
// bool oob_conditional_check = true>
|
||||
// CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
// const statically_indexed_array& offsets,
|
||||
// number<i_access_unsupport_> = {},
|
||||
// bool_constant<oob_conditional_check> = {}) const
|
||||
// {
|
||||
// using Traits = load_store_traits;
|
||||
|
||||
// // using vector_type_t = typename Traits::vector_type_t;
|
||||
// using vector_t = typename Traits::vector_t;
|
||||
// using SFC_Ys = typename Traits::SFC_Ys;
|
||||
|
||||
// constexpr auto tile_dstr = TileDstr{};
|
||||
|
||||
// // loop over thread tensor space [y0, y1, ...]
|
||||
// static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
// auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
// // auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
|
||||
|
||||
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
// window_origin_ +
|
||||
// tuple<index_t, index_t>(0, window_adaptor_thread_coord.get_bottom_index()[1]);
|
||||
|
||||
// auto bottom_tensor_thread_coord = make_tensor_coordinate(
|
||||
// bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
// static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
// constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// // data index [y0, y1, ...]
|
||||
// constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
|
||||
// constexpr auto idx_m = idx_ys_start[number<0>{}];
|
||||
// const auto offset = offsets[idx_m];
|
||||
|
||||
// // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
|
||||
// // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
|
||||
|
||||
// // read from distributed tensor
|
||||
// // vector_type_t vec;
|
||||
// vector_t vec_value;
|
||||
|
||||
// static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
// constexpr auto idx_ys = generate_tuple(
|
||||
// [&](auto jj) {
|
||||
// return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
|
||||
// : idx_ys_start[jj];
|
||||
// },
|
||||
// number<NDimY>{});
|
||||
|
||||
// constexpr index_t d =
|
||||
// tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
|
||||
// Traits::PackedSize;
|
||||
// // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
|
||||
// vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
|
||||
// dstr_tensor.get_thread_buffer().template at<d>();
|
||||
// });
|
||||
|
||||
// // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
|
||||
|
||||
// // write into bottom tensor
|
||||
// get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
|
||||
// bottom_tensor_thread_coord,
|
||||
// offset,
|
||||
// vec_value,
|
||||
// bool_constant<oob_conditional_check>{});
|
||||
// // printf("coord_offset:%d, scatter_offset:%d \n",
|
||||
// // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
|
||||
// if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
// {
|
||||
// constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
|
||||
|
||||
// constexpr auto forward_step_scatter = generate_tuple(
|
||||
// [&](auto i) { return i == 0 ? 0 : idx_diff_ys[i]; }, number<NDimY>{});
|
||||
|
||||
// constexpr auto idx_diff_ps_ys = container_concat(
|
||||
// generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
|
||||
// forward_step_scatter);
|
||||
|
||||
// move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
// window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
// }
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
@@ -1010,23 +1097,6 @@ make_tile_window_raw(const TensorView_& tensor_view,
|
||||
return w;
|
||||
}
|
||||
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
index_t NumCoord>
|
||||
CK_TILE_DEVICE void move_tile_window(
|
||||
tile_window_with_static_distribution<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>& window,
|
||||
const typename tile_window_with_static_distribution<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
NumCoord>::BottomTensorIndex& step)
|
||||
{
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This class provides description of tile windowed view on the device memory.
|
||||
*
|
||||
@@ -1155,13 +1225,4 @@ make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLen
|
||||
return w;
|
||||
}
|
||||
|
||||
template <typename TensorView_, typename WindowLengths_>
|
||||
CK_TILE_DEVICE void move_tile_window(
|
||||
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
|
||||
const typename tile_window_with_static_lengths<TensorView_, WindowLengths_>::BottomTensorIndex&
|
||||
step)
|
||||
{
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1200,19 +1200,4 @@ make_tile_window_linear_raw(const TileWindow_& tile_window,
|
||||
LinearBottomDims_{});
|
||||
}
|
||||
|
||||
template <typename TensorView_,
|
||||
typename WindowLengths_,
|
||||
typename StaticTileDistribution_,
|
||||
typename LinearBottomDims_>
|
||||
CK_TILE_DEVICE void move_tile_window(
|
||||
tile_window_linear<TensorView_, WindowLengths_, StaticTileDistribution_, LinearBottomDims_>&
|
||||
window,
|
||||
const typename tile_window_linear<TensorView_,
|
||||
WindowLengths_,
|
||||
StaticTileDistribution_,
|
||||
LinearBottomDims_>::BottomTensorIndex& step)
|
||||
{
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -18,6 +18,14 @@
|
||||
#pragma once
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TileWindow_>
|
||||
CK_TILE_DEVICE void move_tile_window(
|
||||
TileWindow_& window,
|
||||
const typename TileWindow_::BottomTensorIndex& step)
|
||||
{
|
||||
window.move(step);
|
||||
}
|
||||
|
||||
// input a lds store tile, extract some information from it
|
||||
// used to set m0 value for gfx9 serious
|
||||
template <typename LdsTileWindow_>
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// template <typename ADataType,
|
||||
// typename BDataType,
|
||||
// typename AccDataType,
|
||||
// typename CDataType,
|
||||
// typename AElementOp = ck_tile::identity,
|
||||
// typename BElementOp = ck_tile::identity,
|
||||
// typename ACCElementOp = ck_tile::identity>
|
||||
// CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
// const HostTensor<BDataType>& b_k_n,
|
||||
// HostTensor<CDataType>& c_m_n,
|
||||
// const AElementOp& a_element_op = {},
|
||||
// const BElementOp& b_element_op = {},
|
||||
// const ACCElementOp& acc_element_op = {})
|
||||
// {
|
||||
// const std::size_t M = a_m_k.get_length(0);
|
||||
// const std::size_t N = b_k_n.get_length(1);
|
||||
// const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
// auto f_mn = [&](auto m, auto n) {
|
||||
// AccDataType v_acc = 0;
|
||||
|
||||
// for(std::size_t k = 0; k < K; ++k)
|
||||
// {
|
||||
// AccDataType v_a;
|
||||
// AccDataType v_b;
|
||||
// if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
// {
|
||||
// const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
// if(k % 2 == 1)
|
||||
// v_a = fp32_val.hi;
|
||||
// else
|
||||
// v_a = fp32_val.lo;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
// }
|
||||
// if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
// {
|
||||
// const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
// if(k % 2 == 1)
|
||||
// v_b = fp32_val.hi;
|
||||
// else
|
||||
// v_b = fp32_val.lo;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
// }
|
||||
// v_acc += v_a * v_b;
|
||||
// }
|
||||
|
||||
// c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
// };
|
||||
|
||||
// make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
// }
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
bool IsInputGemm = true>
|
||||
__global__ void naive_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
const ck_tile::index_t* p_sorted_expert_ids_,
|
||||
const ck_tile::index_t* p_max_token_id_,
|
||||
const ADataType* A,
|
||||
const BDataType* B,
|
||||
CDataType* C,
|
||||
ck_tile::index_t Num_tokens,
|
||||
ck_tile::index_t TopK,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t strideA,
|
||||
ck_tile::index_t strideB,
|
||||
ck_tile::index_t strideC)
|
||||
{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int row = idx / N; // Compute row index
|
||||
int col = idx % N; // Compute column index
|
||||
(void)Num_tokens;
|
||||
// assert(p_sorted_expert_ids_ != nullptr);
|
||||
// assert(TopK == 1);
|
||||
// assert(Num_tokens == 128);
|
||||
// if(Num_tokens == 128 && TopK == 1 && p_sorted_expert_ids_ != nullptr) {}
|
||||
|
||||
// index_t max_tokens = p_max_token_id_[0];
|
||||
index_t gather_token_id = 0;
|
||||
index_t scatter_token_id = 0;
|
||||
index_t expert_id = 0;
|
||||
|
||||
if(row < p_max_token_id_[0])
|
||||
{
|
||||
expert_id = p_sorted_expert_ids_[row / 128];
|
||||
gather_token_id = p_sorted_token_ids_[row] & 0xffffff;
|
||||
scatter_token_id = p_sorted_token_ids_[row] & 0xffffff;
|
||||
if(!IsInputGemm)
|
||||
{
|
||||
gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
|
||||
}
|
||||
else
|
||||
{
|
||||
scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if(row < M && col < N)
|
||||
{
|
||||
AccDataType acc = 0.0;
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? gather_token_id * strideA + k
|
||||
: k * strideA + gather_token_id;
|
||||
|
||||
// TODO: add experts weights dispatch
|
||||
int b_index =
|
||||
expert_id * N * K + ((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col);
|
||||
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
}
|
||||
acc += v_a * v_b;
|
||||
}
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? scatter_token_id * strideC + col
|
||||
: col * strideC + scatter_token_id;
|
||||
C[c_index] = ck_tile::type_convert<CDataType>(acc);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
bool IsInputGemm = true>
|
||||
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
|
||||
const index_t* p_sorted_expert_ids_,
|
||||
const index_t* p_max_token_id_,
|
||||
const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t Num_tokens,
|
||||
index_t TopK,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_a,
|
||||
index_t stride_b,
|
||||
index_t stride_c)
|
||||
{
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
naive_gemm_kernel<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
LayoutA,
|
||||
LayoutB,
|
||||
LayoutC,
|
||||
IsInputGemm><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
|
||||
p_sorted_expert_ids_,
|
||||
p_max_token_id_,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
Num_tokens,
|
||||
TopK,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,7 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
@@ -119,6 +119,124 @@ struct CShuffleEpilogue
|
||||
return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
bool IsInputGemm = true,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
void* p_smem,
|
||||
const index_t* p_sorted_tokens_id,
|
||||
index_t token_pos)
|
||||
{
|
||||
|
||||
const index_t iMWarp = get_warp_id() / kNWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
auto in_lds_window =
|
||||
make_tile_window(o_lds_block,
|
||||
make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
|
||||
{number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
|
||||
auto out_lds_window =
|
||||
make_tile_window(o_lds_block,
|
||||
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
kMPerIteration,
|
||||
kNPerIteration,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
// auto coord = dram_tile_distribution.calculate_index();
|
||||
|
||||
// const auto& view = out_dram_window.get_bottom_tensor_view();
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
auto c_coord = dram_tile_distribution.calculate_index();
|
||||
|
||||
// printf("c_coord[0]:%d \n", c_coord[0]);
|
||||
|
||||
CWarpTensor c_warp_in_tensor;
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
// auto idx_m = number<idx_y_start.at(number<0>{})>{} + 0;
|
||||
// printf("idx_y_start:%d \n", idx_m);
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
|
||||
|
||||
statically_indexed_array<index_t, 2> offsets;
|
||||
static_for<0, 2 /*CMrepeats*/, 1>{}([&](auto m0) {
|
||||
auto token_id = token_pos + m0 + c_coord[0] + mIter * kMPerXdl * kMWave;
|
||||
auto fused_token = p_sorted_tokens_id[token_id];
|
||||
|
||||
index_t token_offset = fused_token & 0xffffff;
|
||||
|
||||
if constexpr(IsInputGemm)
|
||||
{
|
||||
token_offset = token_offset * 3 /*TopK*/ + (fused_token >> 24);
|
||||
}
|
||||
|
||||
offsets[m0] = token_offset * 4096; // Problem::kN_;
|
||||
});
|
||||
// printf("c_coord[number<0>{}]: %d \n", coord[number<0>{}]);
|
||||
// printf("mIter: %d", mIter+0);
|
||||
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
|
||||
|
||||
// printf("mIter, nIter(%d, %d) \n", mIter+0, nIter+0);
|
||||
|
||||
c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
|
||||
|
||||
block_sync_lds();
|
||||
store_tile(in_lds_window, c_warp_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
|
||||
const auto c_out_tensor =
|
||||
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
{
|
||||
|
||||
|
||||
auto tile_window = make_tile_scatter_gather(out_dram_window.get_bottom_tensor_view(),
|
||||
out_dram_window.get_window_lengths(),
|
||||
out_dram_window.get_window_origin(),
|
||||
dram_tile_distribution,
|
||||
offsets);
|
||||
|
||||
tile_window.store(c_out_tensor);
|
||||
// store_tile(out_dram_window, c_out_tensor, offsets);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
move_tile_window(out_dram_window, {0, step.at(number<1>{})});
|
||||
// printf("step.at(number<0>{}), step.at(number<1>{}):,%d, %d",
|
||||
// step.at(number<0>{})+0, step.at(number<1>{})+0);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
|
||||
@@ -265,6 +265,7 @@ struct FmhaFwdKernel
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
const int32_t* page_idx;
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
|
||||
@@ -596,6 +597,7 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* page_idx_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -654,7 +656,8 @@ struct FmhaFwdKernel
|
||||
{}, // placeholder for dropout
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(page_idx_ptr)};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -720,6 +723,7 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* page_idx_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -758,6 +762,7 @@ struct FmhaFwdKernel
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
page_idx_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -799,6 +804,7 @@ struct FmhaFwdKernel
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
const void* page_idx_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -837,6 +843,7 @@ struct FmhaFwdKernel
|
||||
seqstart_q_ptr,
|
||||
seqstart_k_ptr,
|
||||
seqlen_k_ptr,
|
||||
page_idx_ptr,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
num_head_q,
|
||||
@@ -958,7 +965,7 @@ struct FmhaFwdKernel
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
// long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_randval = 0;
|
||||
@@ -972,7 +979,7 @@ struct FmhaFwdKernel
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
// batch_offset_k = key_start * kargs.stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
@@ -981,6 +988,8 @@ struct FmhaFwdKernel
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
|
||||
kargs.page_idx += key_start;
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.stride_bias + key_start;
|
||||
@@ -1016,35 +1025,38 @@ struct FmhaFwdKernel
|
||||
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
}
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
batch_offset_randval =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
}
|
||||
// else
|
||||
// {
|
||||
// batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
// batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
// batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
// if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
// {
|
||||
// batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
// }
|
||||
// if constexpr(kStoreLSE)
|
||||
// {
|
||||
// batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
// }
|
||||
// if constexpr(kHasDropout)
|
||||
// {
|
||||
// batch_offset_randval =
|
||||
// static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
|
||||
// }
|
||||
// batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
// }
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
// const KDataType* k_ptr =
|
||||
// reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
// static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
|
||||
// batch_offset_k;
|
||||
const KDataType* k_ptr =
|
||||
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
|
||||
const VDataType* v_ptr =
|
||||
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
|
||||
@@ -1329,6 +1341,9 @@ struct FmhaFwdKernel
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
kargs.page_idx,
|
||||
kargs.stride_k,
|
||||
kargs.stride_v,
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
@@ -1343,6 +1358,9 @@ struct FmhaFwdKernel
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
kargs.page_idx,
|
||||
kargs.stride_k,
|
||||
kargs.stride_v,
|
||||
dropout);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -8,6 +8,9 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/core/tensor/tile_scatter_gather.hpp"
|
||||
// #include "ck_tile/core/tensor/tile_scatter_gather_debug.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -45,6 +48,10 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
@@ -148,6 +155,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
@@ -241,11 +251,10 @@ struct BlockFmhaPipelineQRKSVS
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
{seqlen_k_start, 0}); //todo fixme felix
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
@@ -257,11 +266,28 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
|
||||
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
|
||||
auto v_coord = v_dist.calculate_index();
|
||||
const auto VPageIndexDim = I1;
|
||||
using VDstrEncode = typename decltype(v_dist)::DstrEncode;
|
||||
constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
|
||||
statically_indexed_array<index_t, V_KRepeat> v_offsets;
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
// printf("1tid %d %d %d %d %d\n", threadIdx.x, v_coord[VPageIndexDim], k0.value, page_idx[v_coord[VPageIndexDim] + k0.value], stride_v);
|
||||
});
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
v_dist,
|
||||
v_offsets,
|
||||
VPageIndexDim);
|
||||
// auto v_dram_window =
|
||||
// make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
// v_dram_block_window_tmp.get_window_lengths(),
|
||||
// {0, seqlen_k_start}, // TODO: hdim split?
|
||||
// v_dist);
|
||||
|
||||
auto q_tile = tile_elementwise_in(q_element_func, q);
|
||||
|
||||
@@ -275,12 +301,20 @@ struct BlockFmhaPipelineQRKSVS
|
||||
do
|
||||
{
|
||||
// STAGE 1, QK gemm
|
||||
auto k_dram_window = make_tile_window(
|
||||
auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
|
||||
auto k_coord = k_dist.calculate_index();
|
||||
using KDstrEncode = typename decltype(k_dist)::DstrEncode;
|
||||
constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
|
||||
statically_indexed_array<index_t, NRepeat> k_offsets;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
|
||||
});
|
||||
auto k_dram_window = make_tile_scatter_gather(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
|
||||
// load
|
||||
k_dist,
|
||||
k_offsets); // K DRAM tile window for
|
||||
|
||||
auto k_block_tile = load_tile(k_dram_window);
|
||||
{
|
||||
@@ -321,7 +355,13 @@ struct BlockFmhaPipelineQRKSVS
|
||||
});
|
||||
}
|
||||
|
||||
const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
const auto v_prefetch = v_dram_window.load(); // prefetch load v tile
|
||||
// const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
// printf("2tid %d %d %d %d\n", threadIdx.x, v_coord[VPageIndexDim], kK1 + v_coord[VPageIndexDim] + k0.value, page_idx[kK1 + v_coord[VPageIndexDim] + k0.value]);
|
||||
});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc,
|
||||
@@ -523,6 +563,12 @@ struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
|
||||
static_for<0, V_KRepeat, 1>{}([&](auto k0) {
|
||||
v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
|
||||
// printf("3tid %d %d %d %d\n", threadIdx.x, v_coord[VPageIndexDim], kK1 * 2 + i_k1.value * kK1 + v_coord[VPageIndexDim] + k0.value, page_idx[kK1 + i_k1.value * kK1 + v_coord[VPageIndexDim] + k0.value]);
|
||||
});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
@@ -556,6 +602,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
page_idx += kN0;
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// store lse
|
||||
@@ -626,6 +673,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
const index_t* page_idx,
|
||||
const index_t stride_k,
|
||||
const index_t stride_v,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
@@ -646,6 +696,9 @@ struct BlockFmhaPipelineQRKSVS
|
||||
position_encoding,
|
||||
scale_s,
|
||||
smem_ptr,
|
||||
page_idx,
|
||||
stride_k,
|
||||
stride_v,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -623,7 +623,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t Banks = 32; // TODO: need change based on arch
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(VDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>(); //
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
constexpr index_t NPerRow = PixelsPerRow / kKPack;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
@@ -783,19 +783,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>(); // 8
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P // 16
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K3 = total_pixels / N1; //2
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>(); // 8
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t K2 = kKPack / K3; // //4 TODO: this dimention could be outside single wave
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0); // 2
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size(); // 2
|
||||
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
|
||||
@@ -19,7 +19,6 @@ cmake
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D BUILD_DEV=ON \
|
||||
-D GPU_TARGETS=$GPU_TARGETS \
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
-D USE_BITINT_EXTENSION_INT4=OFF \
|
||||
|
||||
Reference in New Issue
Block a user