mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Merge branch 'develop' into users/ArthurLiu/ck_fmha_codegen
This commit is contained in:
@@ -22,6 +22,7 @@ RUN groupadd -g 109 render && \
|
||||
chmod -R a+rwx /tmp/pytorch && \
|
||||
sudo usermod -aG irc jenkins && \
|
||||
#install hipblaslt
|
||||
cd /tmp && \
|
||||
git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \
|
||||
cd rocm-libraries && \
|
||||
git checkout develop && \
|
||||
@@ -29,4 +30,4 @@ RUN groupadd -g 109 render && \
|
||||
git sparse-checkout set projects/hipblaslt shared/origami && \
|
||||
cd projects/hipblaslt && \
|
||||
git show --oneline -s && \
|
||||
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller
|
||||
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --use-system-packages --architecture="gfx942;gfx950" -j 128 --skip_rocroller
|
||||
|
||||
@@ -10,7 +10,7 @@ if(NOT INST_TARGETS)
|
||||
endif()
|
||||
|
||||
# validate user-specified fmha_fwd API list
|
||||
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
|
||||
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill;batch_prefill")
|
||||
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(BUILD_TESTING)
|
||||
@@ -48,7 +48,6 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
--targets ${FMHA_TARGETS_ARG}
|
||||
--api ${FMHA_FWD_APIS}
|
||||
--optdim 32,64,80,128,256
|
||||
# --filter fmha_fwd...
|
||||
)
|
||||
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
@@ -174,6 +173,13 @@ else()
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
|
||||
endif()
|
||||
|
||||
# conditionally enable call to the batch_prefill API in fmha_fwd example and tests
|
||||
if("batch_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=1)
|
||||
else()
|
||||
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=0)
|
||||
endif()
|
||||
|
||||
# conditionally specify the use of OCP_FP8
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
|
||||
@@ -84,6 +84,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad},
|
||||
{F_qscale},
|
||||
{F_occupancy},
|
||||
false,
|
||||
{F_sink},
|
||||
{F_page_size},
|
||||
{F_kv_memory_layout},
|
||||
{F_kv_lookup_table}>;
|
||||
@@ -124,7 +125,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -201,9 +202,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) &&
|
||||
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -247,6 +248,7 @@ class FmhaFwdApiTrait:
|
||||
skpad: str
|
||||
dpad: str
|
||||
dvpad: str
|
||||
sink: str # t/f
|
||||
constraint: CppConstraint
|
||||
kv_memory_layout: str
|
||||
kv_lookup_table: str
|
||||
@@ -343,6 +345,7 @@ class FmhaFwdPipeline:
|
||||
F_dropout: str #
|
||||
F_qscale: str # no/pertensor
|
||||
F_mask: str # value from MASK_MAP
|
||||
F_sink: str # t/f (StreamLLM sink tokens)
|
||||
F_kv_memory_layout: str #
|
||||
F_kv_lookup_table: str #
|
||||
F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
@@ -406,6 +409,11 @@ class FmhaFwdPipeline:
|
||||
else:
|
||||
n += "_nqscale"
|
||||
|
||||
if self.F_sink == "t":
|
||||
n += "_sink"
|
||||
else:
|
||||
n += "_nsink"
|
||||
|
||||
n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table
|
||||
return n
|
||||
|
||||
@@ -472,6 +480,7 @@ class FmhaFwdApiPool:
|
||||
trait.kv_lookup_table
|
||||
],
|
||||
F_page_size=trait.page_size,
|
||||
F_sink=BOOL_MAP[trait.sink],
|
||||
)
|
||||
if_j = "if" if j == 0 else "else if"
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
|
||||
@@ -578,6 +587,7 @@ class FmhaFwdKernel:
|
||||
F_mode=MODE_MAP[self.F_mode],
|
||||
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_page_size=self.F_page_size,
|
||||
F_sink=BOOL_MAP[self.F_pipeline.F_sink],
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -617,6 +627,7 @@ class FmhaFwdKernel:
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
sink=self.F_pipeline.F_sink,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
|
||||
kv_memory_layout=self.F_pipeline.F_kv_memory_layout,
|
||||
kv_lookup_table=self.F_pipeline.F_kv_lookup_table,
|
||||
@@ -655,6 +666,7 @@ class KernelComponentFactory:
|
||||
bias,
|
||||
lse,
|
||||
dropout,
|
||||
sink,
|
||||
kv_memory_layout,
|
||||
kv_lookup_table,
|
||||
) in itertools.product(
|
||||
@@ -663,12 +675,13 @@ class KernelComponentFactory:
|
||||
BIAS_MAP.keys(),
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
["t", "f"],
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
elif dtype in ["fp8bf16"]:
|
||||
# no need lse/dropout kernels
|
||||
# no need lse/dropout/sink kernels
|
||||
for (
|
||||
logits,
|
||||
qscale,
|
||||
@@ -684,7 +697,7 @@ class KernelComponentFactory:
|
||||
SUPPORTED_KV_MEMORY_LAYOUT,
|
||||
SUPPORTED_KV_LOOKUP_TABLE,
|
||||
):
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", kv_memory_layout, kv_lookup_table)) # fmt: skip
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
@@ -701,20 +714,34 @@ class CustomFactory(KernelComponentFactory):
|
||||
|
||||
|
||||
def get_fwd_blobs(
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
|
||||
kernel_filter: Optional[str], receipt, optdim_list, mask_impl,
|
||||
targets: Optional[List[str]] = None
|
||||
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing
|
||||
# (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with
|
||||
# non-gfx9 architectures (gfx11/gfx12/gfx10 are wave32 and use different
|
||||
# buffer instruction formats). Skip all batch_prefill kernels for non-gfx9 targets.
|
||||
has_non_gfx9 = targets is not None and any(
|
||||
not t.startswith("gfx9") for t in targets
|
||||
)
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
if has_non_gfx9:
|
||||
return api_pool, gen
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = CustomFactory.get_hdim_tile_size_dict(dtype)
|
||||
if d is None:
|
||||
continue
|
||||
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
# batch_prefill pipeline requires group mode (static_assert in pipeline problem)
|
||||
if mode != "group":
|
||||
continue
|
||||
for tile, pipeline in itertools.product(
|
||||
tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)
|
||||
):
|
||||
@@ -829,7 +856,7 @@ def write_blobs(
|
||||
optdim_list,
|
||||
mask_impl,
|
||||
) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
@@ -844,7 +871,7 @@ def list_blobs(
|
||||
mask_impl,
|
||||
) -> None:
|
||||
with file_path.open("a") as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets)
|
||||
for kernel in kernels:
|
||||
f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n")
|
||||
f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n")
|
||||
|
||||
@@ -1452,6 +1452,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_,
|
||||
bool kSkipMinSeqlenQ_ = false,
|
||||
bool kHasSink_ = false,
|
||||
ck_tile::index_t kPageBlockSize_ = 1,
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
@@ -1480,7 +1481,7 @@ struct fmha_fwd_batch_prefill_traits_ : public fmha_fwd_traits_<HDim_,
|
||||
kPadDv_,
|
||||
kUseTrLoad_,
|
||||
kSkipMinSeqlenQ_,
|
||||
false>
|
||||
kHasSink_>
|
||||
{
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
|
||||
@@ -387,7 +387,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
|
||||
#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \
|
||||
CK_TILE_FMHA_FWD_PAGEDKV_API))
|
||||
CK_TILE_FMHA_FWD_PAGEDKV_API || CK_TILE_FMHA_FWD_BATCH_PREFILL_API))
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
|
||||
@@ -395,7 +395,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
page_block_size = 0;
|
||||
}
|
||||
#endif
|
||||
if(!(page_block_size % 128 == 0))
|
||||
// batch_prefill supports flexible page sizes (not just multiples of 128)
|
||||
const bool need_128_aligned_page =
|
||||
(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API ||
|
||||
CK_TILE_FMHA_FWD_PAGEDKV_API);
|
||||
if(need_128_aligned_page && 0 < page_block_size && !(page_block_size % 128 == 0))
|
||||
{
|
||||
std::cerr << "only paged-kvcache block size divisible by 128 are currently supported"
|
||||
<< std::endl;
|
||||
@@ -972,9 +976,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0);
|
||||
// Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with
|
||||
// kvcache or group mode with padding enabled)
|
||||
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
|
||||
? seqlen_ks.size() * sizeof(int32_t)
|
||||
: 0);
|
||||
// batch_prefill (group+kvcache) also needs per-batch seqlen_k for VLLM_BLOCK_TABLE_2D
|
||||
const bool need_seqlen_k_buf = (mode == mode_enum::batch && use_kvcache) ||
|
||||
has_group_k_padding || (mode == mode_enum::group && use_kvcache);
|
||||
ck_tile::DeviceMem seqlen_k_buf(need_seqlen_k_buf ? seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
|
||||
: cuq_cum.size() * sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem cu_seqlen_kv_buf(
|
||||
@@ -1013,9 +1018,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data());
|
||||
cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data());
|
||||
seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr);
|
||||
seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding
|
||||
? seqlen_ks.data()
|
||||
: nullptr);
|
||||
seqlen_k_buf.ToDevice(need_seqlen_k_buf ? seqlen_ks.data() : nullptr);
|
||||
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
|
||||
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
@@ -1146,6 +1149,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
{
|
||||
traits.use_pagedkv = (0 < page_block_size);
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_batch_prefill_traits,
|
||||
std::decay_t<decltype(traits)>>)
|
||||
{
|
||||
traits.has_dropout = (p_drop > 0.0f);
|
||||
traits.qscale_type = qscale.type;
|
||||
traits.kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT;
|
||||
traits.kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D;
|
||||
traits.page_size = page_block_size;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1498,6 +1512,67 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
? seqlen_k_buf.GetDeviceBuffer()
|
||||
: nullptr);
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_batch_prefill_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
// Fields already set by the outer else block above:
|
||||
// bias_ptr, lse_ptr, o_ptr, seqlen_k, max_seqlen_q, scale_s,
|
||||
// logits_soft_cap, stride_bias/o, nhead/batch stride for bias/lse/o,
|
||||
// window_size_left/right, sink_size, mask_type.
|
||||
|
||||
// scale_p/scale_o: batch_prefill-specific fields absent from fmha_fwd_args.
|
||||
args.scale_p = 1.f;
|
||||
args.scale_o = 1.f;
|
||||
|
||||
// Dropout fields: the outer fmha_fwd_args branch sets these; set them here
|
||||
// for batch_prefill since it takes a separate inner branch.
|
||||
args.rand_val_ptr = randval_buf.GetDeviceBuffer();
|
||||
args.stride_randval = stride_randval;
|
||||
args.nhead_stride_randval = nhead_stride_randval;
|
||||
args.batch_stride_randval = batch_stride_randval;
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
if(drop_prefs)
|
||||
args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
|
||||
drop_offset_buf.GetDeviceBuffer());
|
||||
else
|
||||
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
|
||||
|
||||
// Paged KV: LINEAR_LAYOUT + VLLM_BLOCK_TABLE_2D
|
||||
// block_table_buf: [batch, max_blocks_per_seq] of physical page ids
|
||||
// seqlen_k_buf: [batch] of per-batch seqlen_k values
|
||||
args.num_total_pages = max_num_page_blocks;
|
||||
args.page_block_size = page_block_size;
|
||||
args.kv_memory_layout =
|
||||
ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT;
|
||||
args.kv_lookup_table =
|
||||
ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D;
|
||||
args.kv_indptr = nullptr;
|
||||
args.kv_page_indices = block_table_buf.GetDeviceBuffer();
|
||||
args.kv_last_page_lens = nullptr;
|
||||
args.seqlen_k_ptr = seqlen_k_buf.GetDeviceBuffer();
|
||||
args.batch_stride_block_table = batch_stride_block_table;
|
||||
|
||||
// group mode required: seqstart_q is prefix-sum of per-batch seqlen_q
|
||||
args.seqstart_q_ptr = seqstart_q_buf.GetDeviceBuffer();
|
||||
|
||||
// batch_prefill LINEAR_LAYOUT strides for runner's K layout
|
||||
// [max_num_page_blocks, nhead_k, page_block_size, hdim]:
|
||||
// stride_k = hdim_q (token stride within one head's page slice)
|
||||
// nhead_stride_k = page_block_size * hdim_q (head stride)
|
||||
// batch_stride_k = nhead_k * page_block_size * hdim_q (page stride, already set)
|
||||
args.stride_k = hdim_q;
|
||||
args.nhead_stride_k = page_block_size * hdim_q;
|
||||
// V is row-major, same layout convention
|
||||
args.stride_v = hdim_v;
|
||||
args.nhead_stride_v = page_block_size * hdim_v;
|
||||
|
||||
// descale: not used for fp16/bf16
|
||||
args.q_descale_ptr = nullptr;
|
||||
args.k_descale_ptr = nullptr;
|
||||
args.v_descale_ptr = nullptr;
|
||||
args.nblock_stride_kv_block_descale = 0;
|
||||
args.nhead_stride_kv_block_descale = 0;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1524,6 +1599,21 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
|
||||
auto run_fwd = [&](const ck_tile::stream_config& sc) {
|
||||
#if CK_TILE_FMHA_FWD_BATCH_PREFILL_API
|
||||
// batch_prefill: group mode + paged KV, tested against the same CPU reference
|
||||
if(1 == num_splits && use_kvcache && mode == mode_enum::group)
|
||||
{
|
||||
fmha_batch_prefill_traits bp_traits;
|
||||
init_traits(bp_traits);
|
||||
|
||||
fmha_batch_prefill_args bp_args;
|
||||
init_args(bp_args);
|
||||
|
||||
const float ave_time = fmha_batch_prefill(bp_traits, bp_args, sc);
|
||||
if(ave_time >= 0.0f)
|
||||
return ave_time;
|
||||
}
|
||||
#endif // CK_TILE_FMHA_FWD_BATCH_PREFILL_API
|
||||
#if CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
if(1 == num_splits && use_kvcache)
|
||||
{
|
||||
@@ -1844,7 +1934,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
|
||||
}
|
||||
#endif
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || \
|
||||
CK_TILE_FMHA_FWD_BATCH_PREFILL_API
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
// clang-format off
|
||||
@@ -1895,7 +1986,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
});
|
||||
}
|
||||
#endif
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || \
|
||||
CK_TILE_FMHA_FWD_BATCH_PREFILL_API
|
||||
if(0 < page_block_size)
|
||||
{
|
||||
if(is_v_rowmajor)
|
||||
|
||||
@@ -118,14 +118,11 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
|
||||
|
||||
if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
if(kargs.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(kargs.in_ptr,
|
||||
0,
|
||||
zeroing_size * sizeof(typename Types::EDataType),
|
||||
s_conf.stream_id_));
|
||||
}
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(kargs.in_ptr,
|
||||
0,
|
||||
zeroing_size * sizeof(typename Types::EDataType),
|
||||
s_conf.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -566,14 +566,12 @@ def parse_bwd_data_instances(instances, problem_name):
|
||||
if pipeline_version == "V6":
|
||||
print(f"Skipping instance {instance_id} with V6 since it's not supported yet.")
|
||||
continue
|
||||
|
||||
# Check vector sizes for A and B tensors - we cannot oversubscribe.
|
||||
num_tile_elements_a = m_per_xdl * k_per_xdl
|
||||
num_tile_elements_b = n_per_xdl * k_per_xdl
|
||||
max_vector_size_a = max(1, num_tile_elements_a // block_size)
|
||||
max_vector_size_b = max(1, num_tile_elements_b // block_size)
|
||||
a_scalar_per_vector = min(a_scalar_per_vector, max_vector_size_a)
|
||||
b_scalar_per_vector = min(b_scalar_per_vector, max_vector_size_b)
|
||||
if k_per_block > (warp_size * a_scalar_per_vector) or n_per_block > (warp_size * b_scalar_per_vector):
|
||||
print(f"Skipping instance {instance_id} with multiple warps per continous tile dim since it's not supported yet.")
|
||||
continue
|
||||
if a_scalar_per_vector > (m_per_block * k_per_block) // block_size or b_scalar_per_vector > (n_per_block * k_per_block) // block_size:
|
||||
print(f"Skipping instance {instance_id} because current scalar per vector exceedes tile size")
|
||||
continue
|
||||
|
||||
conv = ConvInstanceTemplateParams(
|
||||
spec,
|
||||
|
||||
@@ -28,8 +28,9 @@ namespace ck {
|
||||
|
||||
enum Activation
|
||||
{
|
||||
gelu_and_mul = 0,
|
||||
silu_and_mul = 1
|
||||
gelu_and_mul = 0,
|
||||
silu_and_mul = 1,
|
||||
swiglustep_and_mul = 2
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
|
||||
@@ -1592,6 +1592,25 @@ struct GridwiseMoeGemmBlockScale
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weight;
|
||||
up = up * topk_weight;
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
gate = gate < 7.0f ? gate : 7.0f;
|
||||
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
@@ -2118,6 +2137,25 @@ struct GridwiseMoeGemmBlockScale
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
float up = c_thread_buf_up[cidx];
|
||||
if constexpr(MulRoutedWeight)
|
||||
{
|
||||
gate = gate * topk_weight;
|
||||
up = up * topk_weight;
|
||||
}
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
{
|
||||
gate *= 16;
|
||||
up *= 16;
|
||||
}
|
||||
tensor_operation::element_wise::Silu{}(gate, gate);
|
||||
gate = gate < 7.0f ? gate : 7.0f;
|
||||
up = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
|
||||
c_thread_buf(cidx) = gate * up;
|
||||
}
|
||||
else if(ActivationOperation == Activation::gelu_and_mul)
|
||||
{
|
||||
float gate = c_thread_buf[cidx];
|
||||
|
||||
@@ -759,18 +759,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
const index_t seqlen_k = [&]() {
|
||||
// WA i_batch capture structure binding before c++20
|
||||
const index_t seqlen_k = [&, i_batch_ = i_batch]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
|
||||
const int32_t page_end = kargs.page_table.kv_indptr[i_batch + 1];
|
||||
const int32_t page_start = kargs.page_table.kv_indptr[i_batch_];
|
||||
const int32_t page_end = kargs.page_table.kv_indptr[i_batch_ + 1];
|
||||
const int32_t num_page_blocks = page_end - page_start;
|
||||
const int32_t last_page_len = [&]() {
|
||||
if constexpr(kPageBlockSize == 1)
|
||||
return static_cast<int32_t>(kPageBlockSize);
|
||||
else
|
||||
return kargs.page_table.kv_last_page_lens[i_batch];
|
||||
return kargs.page_table.kv_last_page_lens[i_batch_];
|
||||
}();
|
||||
return num_page_blocks > 0
|
||||
? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
|
||||
@@ -780,21 +781,22 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
|
||||
{
|
||||
if(kargs.page_table.seqlen_k_ptr != nullptr)
|
||||
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
|
||||
return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch_]);
|
||||
else
|
||||
return kargs.seqlen_k;
|
||||
}
|
||||
}();
|
||||
const int32_t* page_idx = [&]() {
|
||||
// WA i_batch capture structure binding before c++20
|
||||
const int32_t* page_idx = [&, i_batch_ = i_batch]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
{
|
||||
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
|
||||
return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch_];
|
||||
}
|
||||
else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
|
||||
{
|
||||
return kargs.page_table.block_table_ptr +
|
||||
static_cast<long_index_t>(i_batch) *
|
||||
static_cast<long_index_t>(i_batch_) *
|
||||
kargs.page_table.batch_stride_block_table;
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -291,6 +291,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
|
||||
// For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
|
||||
// This avoids explicit P *= scale_p and v_descale /= scale_p operations
|
||||
@@ -546,11 +547,25 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
if constexpr(kHasSink)
|
||||
return mask.GetSinkTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
else
|
||||
{
|
||||
auto [start, end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
return ck_tile::make_tuple(0, start, end);
|
||||
}
|
||||
}();
|
||||
const auto sink_seq_end = tile_range_result.get(ck_tile::number<0>{});
|
||||
const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
|
||||
const auto seqlen_k_end = tile_range_result.get(ck_tile::number<2>{});
|
||||
const auto num_sink_loop = integer_divide_ceil(sink_seq_end, kN0);
|
||||
const auto kv_load_start = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
|
||||
const auto num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
|
||||
@@ -576,7 +591,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
{kv_load_start, 0});
|
||||
|
||||
auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
|
||||
auto k_coord = k_dist.calculate_index();
|
||||
@@ -585,7 +600,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
// kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_k_window)
|
||||
// kPageBlockSize < kN0: global offset, must fit int32
|
||||
statically_indexed_array<index_t, NRepeat> k_offsets;
|
||||
index_t current_seq_k = seqlen_k_start;
|
||||
index_t current_seq_k = kv_load_start;
|
||||
|
||||
// Load physical pages first, then compute offsets.
|
||||
// k_physical_pages can be reused for descale lookup later.
|
||||
@@ -668,11 +683,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
{bias_origin.at(number<0>{}), kv_load_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
randval_dram_block_window_tmp, kv_load_start);
|
||||
|
||||
auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
|
||||
auto v_coord = v_dist.calculate_index();
|
||||
@@ -895,7 +910,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto v_dram_window =
|
||||
make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
{0, kv_load_start}, // TODO: hdim split?
|
||||
v_dist,
|
||||
v_offsets,
|
||||
number<1>{}, // HsGatherDim
|
||||
@@ -1097,6 +1112,11 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == num_sink_loop - 1)
|
||||
move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
|
||||
}
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
@@ -1108,19 +1128,36 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !variant.LogitsMask(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
auto apply_mask = [&](auto&& mask_func) {
|
||||
set_tile_if(s_acc,
|
||||
-numeric<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return !mask_func(variant_params,
|
||||
block_indices.batch_idx,
|
||||
row,
|
||||
col,
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
});
|
||||
};
|
||||
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsSinkMask(
|
||||
std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_mask([&](auto&&... args) {
|
||||
return variant.LogitsMask(std::forward<decltype(args)>(args)...);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1297,12 +1334,23 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
{
|
||||
auto randval_ptr = reinterpret_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSizeKV<Problem>();
|
||||
index_t seq_offset = [&]() {
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
const bool in_sink_phase = (num_sink_loop > i_total_loops);
|
||||
if(i_total_loops == num_sink_loop)
|
||||
move_tile_window(randval_dram_window,
|
||||
{0, seqlen_k_start - sink_seq_end});
|
||||
return in_sink_phase
|
||||
? (kv_load_start + i_total_loops * kN0)
|
||||
: (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
|
||||
}
|
||||
else
|
||||
return seqlen_k_start + i_total_loops * kN0;
|
||||
}();
|
||||
dropout
|
||||
.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
|
||||
randval_ptr,
|
||||
seqlen_k_start + i_total_loops * kN0,
|
||||
p_compute,
|
||||
randval_dram_window);
|
||||
randval_ptr, seq_offset, p_compute, randval_dram_window);
|
||||
}
|
||||
|
||||
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
|
||||
@@ -1396,9 +1444,19 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
i_total_loops++;
|
||||
if(i_total_loops < num_total_loop)
|
||||
{
|
||||
current_seq_k += kN0;
|
||||
// For sink: after the last sink tile, jump K/V to seqlen_k_start;
|
||||
// otherwise advance by one normal tile.
|
||||
const index_t k_advance = [&]() -> index_t {
|
||||
if constexpr(kHasSink)
|
||||
return (i_total_loops == num_sink_loop)
|
||||
? (seqlen_k_start - sink_seq_end + kN0)
|
||||
: kN0;
|
||||
else
|
||||
return kN0;
|
||||
}();
|
||||
current_seq_k += k_advance;
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
move_tile_window(k_dram_block_window, {k_advance, 0});
|
||||
k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
|
||||
|
||||
// KV_BLOCKSCALE: reload physical pages for the new tile
|
||||
@@ -1427,6 +1485,21 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
|
||||
k_dram_window.update_page_idx(k_offsets);
|
||||
rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
|
||||
|
||||
// After sink→window transition (i_total_loops == num_sink_loop), V window
|
||||
// was advanced by kN0 (one normal iter), but current_seq_k jumped by k_advance
|
||||
// = seqlen_k_start - sink_seq_end + kN0 > kN0. Re-init V to current_seq_k.
|
||||
if constexpr(kHasSink)
|
||||
{
|
||||
if(i_total_loops == num_sink_loop && num_sink_loop > 0)
|
||||
{
|
||||
prefetch_v_physical_pages(number<0>{});
|
||||
update_v_offsets(number<0>{});
|
||||
v_dram_window.update_page_idx(v_offsets);
|
||||
rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(k1_loops >= 2 &&
|
||||
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
|
||||
__builtin_amdgcn_s_barrier();
|
||||
|
||||
@@ -53,6 +53,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
BlockAttentionQuantScaleEnum QScaleEnum_,
|
||||
index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
|
||||
bool kSkipMinSeqlenQ_ = false, /* skip min seqlen q while chunked prefill */
|
||||
bool kHasSink_ = false, /* StreamLLM sink tokens */
|
||||
index_t kPageBlockSize_ = 1,
|
||||
BlockAttentionKVCacheMemoryLayoutEnum kKVMemoryLayout_ =
|
||||
BlockAttentionKVCacheMemoryLayoutEnum::VECTORIZED_LAYOUT,
|
||||
@@ -70,7 +71,7 @@ struct TileFmhaBatchPrefillTraits : public TileFmhaTraits<kPadSeqLenQ_,
|
||||
QScaleEnum_,
|
||||
kBlockPerCu_,
|
||||
kSkipMinSeqlenQ_,
|
||||
false>
|
||||
kHasSink_>
|
||||
{
|
||||
static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
|
||||
static constexpr auto kKVLookupTable = kKVLookupTable_;
|
||||
|
||||
@@ -22,6 +22,17 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
target_link_libraries(test_grouped_conv_bwd_data_scale PRIVATE gtest_main getopt::getopt utility device_grouped_conv3d_bwd_data_scale_instance)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(CK_EXPERIMENTAL_BUILDER)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_tile test_grouped_convnd_bwd_data_tile.cpp)
|
||||
target_compile_options(test_grouped_convnd_bwd_data_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_tile PRIVATE gtest_main getopt::getopt utility)
|
||||
if(TARGET device_grouped_conv_bwd_data_tile_instances)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_tile PRIVATE device_grouped_conv_bwd_data_tile_instances)
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp)
|
||||
if(result EQUAL 0)
|
||||
|
||||
@@ -0,0 +1,258 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "profiler/grouped_convolution_backward_data_tile_algs.hpp"
|
||||
|
||||
static ck::index_t args_mask = 0xffff;
|
||||
static ck::index_t instance_index = -1;
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace ckp = ck_tile::builder::profiling;
|
||||
|
||||
template <ck_tile::index_t num_spatial_dim_,
|
||||
ckb::DataType data_type_,
|
||||
ckb::DataType acc_data_type_,
|
||||
ckb::TensorLayout in_layout_,
|
||||
ckb::TensorLayout wei_layout_,
|
||||
ckb::TensorLayout out_layout_>
|
||||
struct SignatureDetails
|
||||
{
|
||||
static constexpr ck_tile::index_t num_spatial_dim = num_spatial_dim_;
|
||||
static constexpr ckb::DataType data_type = data_type_;
|
||||
static constexpr ckb::DataType acc_data_type = acc_data_type_;
|
||||
static constexpr ckb::TensorLayout in_layout = in_layout_;
|
||||
static constexpr ckb::TensorLayout wei_layout = wei_layout_;
|
||||
static constexpr ckb::TensorLayout out_layout = out_layout_;
|
||||
};
|
||||
|
||||
template <typename SignatureDetailsType>
|
||||
class TestGroupedConvndBwdDataTile : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
static constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = SignatureDetailsType::num_spatial_dim,
|
||||
.direction = ckb::ConvDirection::BACKWARD_DATA,
|
||||
.data_type = SignatureDetailsType::data_type,
|
||||
.accumulation_data_type = SignatureDetailsType::acc_data_type,
|
||||
.input = {.config = {.layout = SignatureDetailsType::in_layout}},
|
||||
.weight = {.config = {.layout = SignatureDetailsType::wei_layout}},
|
||||
.output = {.config = {.layout = SignatureDetailsType::out_layout}}};
|
||||
|
||||
std::vector<ckt::Args<SIGNATURE>> conv_args;
|
||||
std::vector<std::string> split_ks{"1", "2"};
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
ASSERT_FALSE(conv_args.empty());
|
||||
bool pass = true;
|
||||
for(size_t i = 0; i < conv_args.size(); i++)
|
||||
{
|
||||
for(auto& split_k : split_ks)
|
||||
{
|
||||
if((args_mask & (1 << i)) == 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
auto& args = conv_args[i];
|
||||
|
||||
auto inputs = alloc_inputs(args);
|
||||
auto outputs = alloc_outputs(args);
|
||||
ckt::init_tensor_buffer_uniform_int(
|
||||
inputs.get().weight, args.make_weight_descriptor(), -5, 5);
|
||||
ckt::init_tensor_buffer_uniform_int(
|
||||
inputs.get().output, args.make_output_descriptor(), -5, 5);
|
||||
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemset(outputs.get().input,
|
||||
0,
|
||||
args.make_input_descriptor().get_element_space_size_in_bytes()));
|
||||
|
||||
std::cout << args.make_input_descriptor() << std::endl;
|
||||
std::cout << args.make_weight_descriptor() << std::endl;
|
||||
std::cout << args.make_output_descriptor() << std::endl;
|
||||
[[maybe_unused]] auto&& [case_passed,
|
||||
avg_time,
|
||||
op_name,
|
||||
best_split_k,
|
||||
best_instance] =
|
||||
|
||||
ckp::run_grouped_conv_backward_data_tile_algs(
|
||||
args,
|
||||
split_k,
|
||||
-1,
|
||||
inputs.get(),
|
||||
outputs.get(),
|
||||
ck_tile::stream_config{nullptr, false /*time_kernel*/});
|
||||
|
||||
pass = pass && case_passed;
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
void conv_args_append(std::size_t,
|
||||
std::size_t G,
|
||||
std::size_t N,
|
||||
std::size_t K,
|
||||
std::size_t C,
|
||||
const std::vector<std::size_t>& filter_spatial_lengths,
|
||||
const std::vector<std::size_t>& input_spatial_lengths,
|
||||
const std::vector<std::size_t>& conv_filter_strides,
|
||||
const std::vector<std::size_t>& conv_filter_dilations,
|
||||
const std::vector<std::size_t>& input_left_pads,
|
||||
const std::vector<std::size_t>& input_right_pads)
|
||||
{
|
||||
ckt::Args<SIGNATURE> args = {
|
||||
.lengths =
|
||||
{
|
||||
.batch_size = N,
|
||||
.groups = G,
|
||||
.input_channels = C,
|
||||
.output_channels = K,
|
||||
.image = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
input_spatial_lengths),
|
||||
.filter = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
filter_spatial_lengths),
|
||||
},
|
||||
.filter_strides = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
conv_filter_strides),
|
||||
.filter_dilation =
|
||||
ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
conv_filter_dilations),
|
||||
.input_left_pad = ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
input_left_pads),
|
||||
.input_right_pad =
|
||||
ckt::filter_extent_from_vector<SignatureDetailsType::num_spatial_dim>(
|
||||
input_right_pads),
|
||||
.a_elementwise_op = {},
|
||||
.b_elementwise_op = {},
|
||||
.cde_elementwise_op = {},
|
||||
};
|
||||
conv_args.push_back(args);
|
||||
}
|
||||
};
|
||||
|
||||
using KernelTypes2d = ::testing::Types<SignatureDetails<2,
|
||||
ckb::DataType::FP32,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NHWGC,
|
||||
ckb::TensorLayout::GKYXC,
|
||||
ckb::TensorLayout::NHWGK>,
|
||||
SignatureDetails<2,
|
||||
ckb::DataType::FP16,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NHWGC,
|
||||
ckb::TensorLayout::GKYXC,
|
||||
ckb::TensorLayout::NHWGK>,
|
||||
SignatureDetails<2,
|
||||
ckb::DataType::BF16,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NHWGC,
|
||||
ckb::TensorLayout::GKYXC,
|
||||
ckb::TensorLayout::NHWGK>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<SignatureDetails<3,
|
||||
ckb::DataType::FP32,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NDHWGC,
|
||||
ckb::TensorLayout::GKZYXC,
|
||||
ckb::TensorLayout::NDHWGK>,
|
||||
SignatureDetails<3,
|
||||
ckb::DataType::FP16,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NDHWGC,
|
||||
ckb::TensorLayout::GKZYXC,
|
||||
ckb::TensorLayout::NDHWGK>,
|
||||
SignatureDetails<3,
|
||||
ckb::DataType::BF16,
|
||||
ckb::DataType::FP32,
|
||||
ckb::TensorLayout::NDHWGC,
|
||||
ckb::TensorLayout::GKZYXC,
|
||||
ckb::TensorLayout::NDHWGK>>;
|
||||
|
||||
template <typename SignatureDetailsType>
|
||||
class TestGroupedConvndBwdDataTile2d : public TestGroupedConvndBwdDataTile<SignatureDetailsType>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename SignatureDetailsType>
|
||||
class TestGroupedConvndBwdDataTile3d : public TestGroupedConvndBwdDataTile<SignatureDetailsType>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataTile2d, KernelTypes2d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataTile3d, KernelTypes3d);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataTile2d, Test2D)
|
||||
{
|
||||
this->conv_args.clear();
|
||||
|
||||
// GroupedGemmGroupsNum = 4, ZTilde * YTilde * XTilde = 4, MaxGroupedGemmGroupsNum = 32
|
||||
this->conv_args_append(2, 2, 2, 16, 16, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1});
|
||||
// GroupedGemmGroupsNum = 9, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
|
||||
this->conv_args_append(2, 2, 2, 16, 16, {3, 3}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1});
|
||||
// GroupedGemmGroupsNum = 36, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
|
||||
this->conv_args_append(2, 2, 2, 16, 16, {6, 6}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1});
|
||||
// GroupedGemmGroupsNum = 32, ZTilde * YTilde * XTilde = 32, MaxGroupedGemmGroupsNum = 32
|
||||
this->conv_args_append(2, 2, 2, 16, 16, {4, 8}, {28, 28}, {4, 8}, {1, 1}, {1, 1}, {1, 1});
|
||||
this->conv_args_append(2, 2, 2, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
|
||||
this->conv_args_append(2, 2, 2, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
|
||||
this->conv_args_append(2, 2, 2, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0});
|
||||
this->conv_args_append(2, 2, 2, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0});
|
||||
this->conv_args_append(2, 2, 2, 32, 32, {2, 2}, {12, 12}, {3, 3}, {1, 1}, {0, 0}, {0, 0});
|
||||
this->conv_args_append(2, 2, 2, 32, 32, {2, 2}, {12, 12}, {2, 2}, {2, 2}, {0, 0}, {0, 0});
|
||||
this->conv_args_append(2, 1, 6, 448, 896, {1, 1}, {118, 182}, {2, 2}, {1, 1}, {0, 0}, {0, 0});
|
||||
this->conv_args_append(2, 1, 1, 1, 32, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
|
||||
this->conv_args_append(2, 1, 1, 64, 3, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
|
||||
this->conv_args_append(2, 1, 1, 1, 1, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
|
||||
this->template Run<2>();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataTile3d, Test3D)
|
||||
{
|
||||
this->conv_args.clear();
|
||||
this->conv_args_append(
|
||||
3, 2, 2, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0});
|
||||
this->conv_args_append(
|
||||
3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
|
||||
this->conv_args_append(
|
||||
3, 2, 2, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0});
|
||||
this->conv_args_append(
|
||||
3, 2, 2, 32, 32, {1, 2, 2}, {1, 12, 12}, {1, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0});
|
||||
this->conv_args_append(
|
||||
3, 2, 2, 32, 32, {1, 2, 2}, {1, 12, 12}, {1, 2, 2}, {1, 2, 2}, {0, 0, 0}, {0, 0, 0});
|
||||
this->conv_args_append(
|
||||
3, 1, 1, 1, 32, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
|
||||
this->conv_args_append(
|
||||
3, 1, 1, 64, 3, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
|
||||
this->conv_args_append(
|
||||
3, 1, 1, 1, 1, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
|
||||
this->template Run<3>();
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1) {}
|
||||
else if(argc == 3)
|
||||
{
|
||||
args_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: args_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
Reference in New Issue
Block a user