From b75afb4274389d9c7304db122bc4c7fc1349efe1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=87=91=E9=BB=84=E8=89=B2=E8=91=A1=E8=90=84=E7=90=83?= =?UTF-8?q?=E5=90=9B=E5=90=9B?=
Date: Tue, 21 Apr 2026 07:25:52 +0000
Subject: [PATCH 1/4] [rocm-libraries] ROCm/rocm-libraries#6118 (commit 2c7dcf7) projects/composablekernel: add SwigluStep support for MoE blockscale (#6118)

## Summary

- add `swiglustep_and_mul` to the composablekernel MoE blockscale activation enum
- implement the corresponding blockscale epilogue path for `SwigluStep`
- keep existing `silu` and `gelu` paths unchanged

## Scope

This PR covers the classic composablekernel blockscale MoE path under `projects/composablekernel`. This is separate from the `ck_tile` / FlatMM path being discussed in ROCm/rocm-libraries#5992.

## Motivation

`Step-3.5-Flash-FP8` uses `SwigluStep` in its MoE MLP path. The dependent AITER change needs native support for this activation in the classic composablekernel MoE blockscale path.

## Validation

- the patch is limited to two composablekernel files under `projects/composablekernel`
- the existing `silu` / `gelu` paths are unchanged
- dependent AITER runtime validation hit the classic CK 2-stage path with AITER MoE enabled
---
 .../gridwise_gemm_xdl_cshuffle_common.hpp     |  5 ++-
 .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 38 +++++++++++++++++++
 2 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp
index 6e047dd64a..2f9a9cd21b 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp
@@ -28,8 +28,9 @@ namespace ck {
 enum Activation
 {
-    gelu_and_mul = 0,
-    silu_and_mul = 1
+    gelu_and_mul       = 0,
+    silu_and_mul       = 1,
+    swiglustep_and_mul = 2
 };
 
 template
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
--- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
+++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp
@@ ... @@
                 tensor_operation::element_wise::Silu{}(gate, gate);
                 c_thread_buf(cidx) = gate * up;
             }
+            else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
+            {
+                float gate = c_thread_buf[cidx];
+                float up   = c_thread_buf_up[cidx];
+                if constexpr(MulRoutedWeight)
+                {
+                    gate = gate * topk_weight;
+                    up   = up * topk_weight;
+                }
+                if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
+                {
+                    gate *= 16;
+                    up *= 16;
+                }
+                tensor_operation::element_wise::Silu{}(gate, gate);
+                gate = gate < 7.0f ? gate : 7.0f;
+                up   = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
+                c_thread_buf(cidx) = gate * up;
+            }
             else if(ActivationOperation == Activation::gelu_and_mul)
             {
                 float gate = c_thread_buf[cidx];
@@ -2118,6 +2137,25 @@ struct GridwiseMoeGemmBlockScale
                 tensor_operation::element_wise::Silu{}(gate, gate);
                 c_thread_buf(cidx) = gate * up;
             }
+            else if constexpr(ActivationOperation == Activation::swiglustep_and_mul)
+            {
+                float gate = c_thread_buf[cidx];
+                float up   = c_thread_buf_up[cidx];
+                if constexpr(MulRoutedWeight)
+                {
+                    gate = gate * topk_weight;
+                    up   = up * topk_weight;
+                }
+                if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
+                {
+                    gate *= 16;
+                    up *= 16;
+                }
+                tensor_operation::element_wise::Silu{}(gate, gate);
+                gate = gate < 7.0f ? gate : 7.0f;
+                up   = up < 7.0f ? (up > -7.0f ? up : -7.0f) : 7.0f;
+                c_thread_buf(cidx) = gate * up;
+            }
             else if(ActivationOperation == Activation::gelu_and_mul)
             {
                 float gate = c_thread_buf[cidx];

From d22aafb48baf3ba2960751b3e08be4a66ec487b7 Mon Sep 17 00:00:00 2001
From: Linjun-AMD
Date: Tue, 21 Apr 2026 11:05:12 +0000
Subject: [PATCH 2/4] [rocm-libraries] ROCm/rocm-libraries#6479 (commit 0705c2d) [CK][fmha] Add StreamLLM sink support to batch_prefill pipeline (#6479)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

The existing paged-KV attention pipelines (pagedkv, splitkv) support StreamLLM-style sink tokens: a fixed set of initial tokens kept in attention alongside the sliding window. The `batch_prefill` pipeline (chunked prefill with VLLM-style block tables) previously hardcoded `kHasSink = false`, making it incompatible with sink-based attention patterns in LLM serving scenarios.

This PR extends `batch_prefill` to support `kHasSink` and wires it into `fmha_fwd_runner` for validation against the existing CPU reference.

## Technical Details

**Pipeline** (`block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp`):
- When `kHasSink`, the K/V loop splits into a sink phase [0, sink_seq_end) and a window phase [seqlen_k_start, seqlen_k_end), mirroring pagedkv.
- The K advance at the sink→window transition jumps by `seqlen_k_start - sink_seq_end + kN0` to bridge the gap.
- V scatter-gather offsets are re-initialized at the transition to fix a window-mismatch bug: V was lagging kN0 behind K after the large jump, loading from the wrong sequence position.
- The bias window, dropout seq_offset, and mask type (LogitsSinkMask) are updated for sink awareness.

**Traits / codegen** (`tile_fmha_traits.hpp`, `fmha_fwd.hpp`, `fmha_batch_prefill.py`):
- `TileFmhaBatchPrefillTraits` gains `kHasSink_` (previously hardcoded `false`).
- Codegen adds an `F_sink` field and skips batch-mode kernels (group mode is required).
- The CMake test filter is broadened from 9 to 33 instances covering fp16/bf16 × mask/nmask × lse/nlse × sink/nsink.

**Runner** (`fmha_fwd_runner.hpp`, `CMakeLists.txt`):
- `fmha_batch_prefill()` is dispatched from `run_fwd` when group mode, paged KV, and num_splits == 1 all hold.
- K/V strides are corrected for the runner's [num_pages, nhead_k, page_block_size, hdim] layout.
- The `page_block_size % 128` check is relaxed: batch_prefill supports page size 16.
- CPU-reference paged-KV reordering guards are extended with `CK_TILE_FMHA_FWD_BATCH_PREFILL_API`.

## Test Plan

Build with `-DFMHA_FWD_ENABLE_APIS="fwd;batch_prefill"`, run `tile_example_fmha_fwd` in group mode with page_block_size=16. Test matrix:

- Mask: no-mask, causal, sliding window
- Sink: nsink, sink=1..128
- dtype: fp16, bf16
- LSE output: on/off
- seqlen ∈ {512,1024,2048,4096} × window ∈ {32,256,512,1024}
- GQA, chunked prefill, large batch×seqlen
- page_block_size: 16, 32

## Test Result

171 test cases, all valid:

- nmask + nsink: ✓
- causal + nsink: ✓
- causal + sink=8: ✓
- sliding window + sink=8 (d=128, d=256): ✓
- bf16, LSE output, GQA: ✓

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
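For intuition, the two-phase K/V walk reduces to the standalone sketch below. This is an illustration only, not pipeline code: `kN0`, `sink_seq_end`, `seqlen_k_start`/`seqlen_k_end` and `num_sink_loop` mirror the names used in the pipeline, while the concrete numbers are made up.

```cpp
#include <cstdio>

int main()
{
    constexpr int kN0        = 64;   // K/V tile width along seqlen_k
    const int sink_seq_end   = 128;  // sink region: tokens [0, 128)
    const int seqlen_k_start = 896;  // window start chosen by the mask
    const int seqlen_k_end   = 1152; // window end

    const int num_sink_loop  = (sink_seq_end + kN0 - 1) / kN0;
    const int num_total_loop =
        num_sink_loop + (seqlen_k_end - seqlen_k_start + kN0 - 1) / kN0;

    int current_seq_k = 0; // with a sink, the K walk starts at position 0
    for(int i = 0; i < num_total_loop; ++i)
    {
        std::printf("iter %d: K/V tile at seq position %d\n", i, current_seq_k);
        // bridge the sink->window gap exactly once, otherwise step by kN0
        const int k_advance = ((i + 1) == num_sink_loop)
                                  ? (seqlen_k_start - sink_seq_end + kN0)
                                  : kN0;
        current_seq_k += k_advance;
    }
    return 0;
}
```

With these values the walk visits tiles 0 and 64 (sink phase), then jumps to 896 and continues through 1088; the V-side bug fixed here was that the V window kept stepping by kN0 across that jump instead of taking it.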
--- example/ck_tile/01_fmha/CMakeLists.txt | 10 +- .../01_fmha/codegen/ops/fmha_batch_prefill.py | 45 ++++-- example/ck_tile/01_fmha/fmha_fwd.hpp | 3 +- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 112 +++++++++++++-- .../fmha/kernel/fmha_batch_prefill_kernel.hpp | 18 +-- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 129 ++++++++++++++---- .../ops/fmha/pipeline/tile_fmha_traits.hpp | 3 +- 7 files changed, 261 insertions(+), 59 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 35afb1181e..02d31d2324 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -10,7 +10,7 @@ if(NOT INST_TARGETS) endif() # validate user-specified fmha_fwd API list -set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill") +set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill;batch_prefill") set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".") if(BUILD_TESTING) @@ -48,7 +48,6 @@ set(FMHA_FWD_CODE_GEN_COMMON_ARGS --targets ${FMHA_TARGETS_ARG} --api ${FMHA_FWD_APIS} --optdim 32,64,80,128,256 - # --filter fmha_fwd... ) set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py @@ -174,6 +173,13 @@ else() list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0) endif() +# conditionally enable call to the batch_prefill API in fmha_fwd example and tests +if("batch_prefill" IN_LIST FMHA_FWD_ENABLE_APIS) + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=1) +else() + list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_BATCH_PREFILL_API=0) +endif() + # conditionally specify the use of OCP_FP8 if(CK_USE_OCP_FP8) list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py index 35e8c1be49..7c3efb9c18 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_prefill.py @@ -84,6 +84,7 @@ using fmha_trait_{F_idx} = ck_tile::TileFmhaBatchPrefillTraits<{F_spad}, {F_qscale}, {F_occupancy}, false, + {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; @@ -124,7 +125,7 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel; using trait_{F_idx} = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, - {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; + {F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; #include @@ -201,9 +202,9 @@ FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v }} """ -FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && +FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && 
(t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint}) && (t.kv_memory_layout == {F_kv_memory_layout}) && (t.kv_lookup_table == {F_kv_lookup_table}) && (t.page_size == {F_page_size})) {{ - using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; + using trait_ = fmha_fwd_batch_prefill_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false, false, {F_sink}, {F_page_size}, {F_kv_memory_layout}, {F_kv_lookup_table}>; return fmha_batch_prefill_(s, a); }} """ @@ -247,6 +248,7 @@ class FmhaFwdApiTrait: skpad: str dpad: str dvpad: str + sink: str # t/f constraint: CppConstraint kv_memory_layout: str kv_lookup_table: str @@ -343,6 +345,7 @@ class FmhaFwdPipeline: F_dropout: str # F_qscale: str # no/pertensor F_mask: str # value from MASK_MAP + F_sink: str # t/f (StreamLLM sink tokens) F_kv_memory_layout: str # F_kv_lookup_table: str # F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint()) @@ -406,6 +409,11 @@ class FmhaFwdPipeline: else: n += "_nqscale" + if self.F_sink == "t": + n += "_sink" + else: + n += "_nsink" + n += "_" + self.F_kv_memory_layout + "_" + self.F_kv_lookup_table return n @@ -472,6 +480,7 @@ class FmhaFwdApiPool: trait.kv_lookup_table ], F_page_size=trait.page_size, + F_sink=BOOL_MAP[trait.sink], ) if_j = "if" if j == 0 else "else if" per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format( @@ -578,6 +587,7 @@ class FmhaFwdKernel: F_mode=MODE_MAP[self.F_mode], F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag], F_page_size=self.F_page_size, + F_sink=BOOL_MAP[self.F_pipeline.F_sink], ) @property @@ -617,6 +627,7 @@ class FmhaFwdKernel: skpad=self.F_pipeline.F_skpad, dpad=self.F_pipeline.F_dpad, dvpad=self.F_pipeline.F_dvpad, + sink=self.F_pipeline.F_sink, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, kv_memory_layout=self.F_pipeline.F_kv_memory_layout, kv_lookup_table=self.F_pipeline.F_kv_lookup_table, @@ -655,6 +666,7 @@ class KernelComponentFactory: bias, lse, dropout, + sink, kv_memory_layout, kv_lookup_table, ) in itertools.product( @@ -663,12 +675,13 @@ class KernelComponentFactory: BIAS_MAP.keys(), ["t", "f"], ["t", "f"], + ["t", "f"], SUPPORTED_KV_MEMORY_LAYOUT, SUPPORTED_KV_LOOKUP_TABLE, ): - pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, sink, kv_memory_layout, kv_lookup_table)) # fmt: skip elif dtype in ["fp8bf16"]: - # no need lse/dropout kernels + # no need lse/dropout/sink kernels for ( logits, qscale, @@ -684,7 +697,7 @@ class KernelComponentFactory: SUPPORTED_KV_MEMORY_LAYOUT, SUPPORTED_KV_LOOKUP_TABLE, ): - 
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, kv_memory_layout, kv_lookup_table)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", kv_memory_layout, kv_lookup_table)) # fmt: skip else: assert False return pipelines @@ -701,20 +714,34 @@ class CustomFactory(KernelComponentFactory): def get_fwd_blobs( - kernel_filter: Optional[str], receipt, optdim_list, mask_impl + kernel_filter: Optional[str], receipt, optdim_list, mask_impl, + targets: Optional[List[str]] = None ) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: + # batch_prefill pipeline uses gfx9-specific async scatter-gather buffer addressing + # (amd_buffer_addressing.hpp raw buffer loads) that is not compatible with + # non-gfx9 architectures (gfx11/gfx12/gfx10 are wave32 and use different + # buffer instruction formats). Skip all batch_prefill kernels for non-gfx9 targets. + has_non_gfx9 = targets is not None and any( + not t.startswith("gfx9") for t in targets + ) # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future gen = list() api_pool = FmhaFwdApiPool(mask_impl) + if has_non_gfx9: + return api_pool, gen + for dtype in FWD_DTYPE_MAP.keys(): d = CustomFactory.get_hdim_tile_size_dict(dtype) if d is None: continue # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()): + # batch_prefill pipeline requires group mode (static_assert in pipeline problem) + if mode != "group": + continue for tile, pipeline in itertools.product( tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl) ): @@ -829,7 +856,7 @@ def write_blobs( optdim_list, mask_impl, ) -> None: - api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets) for kernel in kernels: write_single_fwd_kernel(kernel, output_dir) write_fwd_api(api_pool, output_dir) @@ -844,7 +871,7 @@ def list_blobs( mask_impl, ) -> None: with file_path.open("a") as f: - _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl) + _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl, targets) for kernel in kernels: f.write((file_path.parent / GEN_DIR / kernel.filename).as_posix() + "\n") f.write((file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME).as_posix() + "\n") diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 7d7d01bd05..6c842def58 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1452,6 +1452,7 @@ template + kHasSink_> { static constexpr auto kKVMemoryLayout = kKVMemoryLayout_; static constexpr auto kKVLookupTable = kKVLookupTable_; diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 40b8006381..e01555193f 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -387,7 +387,7 @@ fwd_result fmha_fwd_run(mode_enum mode, } #if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \ - CK_TILE_FMHA_FWD_PAGEDKV_API)) + CK_TILE_FMHA_FWD_PAGEDKV_API || CK_TILE_FMHA_FWD_BATCH_PREFILL_API)) if(0 < page_block_size) { std::cerr << "paged-kvcache is not supported. 
ignoring the 'page_block_size' option" @@ -395,7 +395,11 @@ fwd_result fmha_fwd_run(mode_enum mode, page_block_size = 0; } #endif - if(!(page_block_size % 128 == 0)) + // batch_prefill supports flexible page sizes (not just multiples of 128) + const bool need_128_aligned_page = + (CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || + CK_TILE_FMHA_FWD_PAGEDKV_API); + if(need_128_aligned_page && 0 < page_block_size && !(page_block_size % 128 == 0)) { std::cerr << "only paged-kvcache block size divisible by 128 are currently supported" << std::endl; @@ -972,9 +976,10 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem seqlen_q_buf(has_group_q_padding ? seqlen_qs.size() * sizeof(int32_t) : 0); // Buffers for key/value per-sequence logical (unpadded) lengths (used in batch mode with // kvcache or group mode with padding enabled) - ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || has_group_k_padding - ? seqlen_ks.size() * sizeof(int32_t) - : 0); + // batch_prefill (group+kvcache) also needs per-batch seqlen_k for VLLM_BLOCK_TABLE_2D + const bool need_seqlen_k_buf = (mode == mode_enum::batch && use_kvcache) || + has_group_k_padding || (mode == mode_enum::group && use_kvcache); + ck_tile::DeviceMem seqlen_k_buf(need_seqlen_k_buf ? seqlen_ks.size() * sizeof(int32_t) : 0); ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 : cuq_cum.size() * sizeof(ck_tile::index_t)); ck_tile::DeviceMem cu_seqlen_kv_buf( @@ -1013,9 +1018,7 @@ fwd_result fmha_fwd_run(mode_enum mode, cu_seqlen_q_buf.ToDevice(cuq_cum.empty() ? nullptr : cuq_cum.data()); cu_seqlen_kv_buf.ToDevice(cukv_cum.empty() ? nullptr : cukv_cum.data()); seqlen_q_buf.ToDevice(has_group_q_padding ? seqlen_qs.data() : nullptr); - seqlen_k_buf.ToDevice((mode == mode_enum::batch && use_kvcache) || has_group_k_padding - ? seqlen_ks.data() - : nullptr); + seqlen_k_buf.ToDevice(need_seqlen_k_buf ? seqlen_ks.data() : nullptr); cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr); rotary_cos_buf.ToDevice(rotary_cos_host.data()); rotary_sin_buf.ToDevice(rotary_sin_host.data()); @@ -1146,6 +1149,17 @@ fwd_result fmha_fwd_run(mode_enum mode, { traits.use_pagedkv = (0 < page_block_size); } + else if constexpr(std::is_same_v>) + { + traits.has_dropout = (p_drop > 0.0f); + traits.qscale_type = qscale.type; + traits.kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT; + traits.kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D; + traits.page_size = page_block_size; + } } }; @@ -1498,6 +1512,67 @@ fwd_result fmha_fwd_run(mode_enum mode, ? seqlen_k_buf.GetDeviceBuffer() : nullptr); } + else if constexpr(std::is_same_v>) + { + // Fields already set by the outer else block above: + // bias_ptr, lse_ptr, o_ptr, seqlen_k, max_seqlen_q, scale_s, + // logits_soft_cap, stride_bias/o, nhead/batch stride for bias/lse/o, + // window_size_left/right, sink_size, mask_type. + + // scale_p/scale_o: batch_prefill-specific fields absent from fmha_fwd_args. + args.scale_p = 1.f; + args.scale_o = 1.f; + + // Dropout fields: the outer fmha_fwd_args branch sets these; set them here + // for batch_prefill since it takes a separate inner branch. 
+ args.rand_val_ptr = randval_buf.GetDeviceBuffer(); + args.stride_randval = stride_randval; + args.nhead_stride_randval = nhead_stride_randval; + args.batch_stride_randval = batch_stride_randval; + args.p_drop = p_drop; + args.s_randval = s_randval; + if(drop_prefs) + args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(), + drop_offset_buf.GetDeviceBuffer()); + else + args.drop_seed_offset = std::make_pair(drop_seed, drop_offset); + + // Paged KV: LINEAR_LAYOUT + VLLM_BLOCK_TABLE_2D + // block_table_buf: [batch, max_blocks_per_seq] of physical page ids + // seqlen_k_buf: [batch] of per-batch seqlen_k values + args.num_total_pages = max_num_page_blocks; + args.page_block_size = page_block_size; + args.kv_memory_layout = + ck_tile::BlockAttentionKVCacheMemoryLayoutEnum::LINEAR_LAYOUT; + args.kv_lookup_table = + ck_tile::BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D; + args.kv_indptr = nullptr; + args.kv_page_indices = block_table_buf.GetDeviceBuffer(); + args.kv_last_page_lens = nullptr; + args.seqlen_k_ptr = seqlen_k_buf.GetDeviceBuffer(); + args.batch_stride_block_table = batch_stride_block_table; + + // group mode required: seqstart_q is prefix-sum of per-batch seqlen_q + args.seqstart_q_ptr = seqstart_q_buf.GetDeviceBuffer(); + + // batch_prefill LINEAR_LAYOUT strides for runner's K layout + // [max_num_page_blocks, nhead_k, page_block_size, hdim]: + // stride_k = hdim_q (token stride within one head's page slice) + // nhead_stride_k = page_block_size * hdim_q (head stride) + // batch_stride_k = nhead_k * page_block_size * hdim_q (page stride, already set) + args.stride_k = hdim_q; + args.nhead_stride_k = page_block_size * hdim_q; + // V is row-major, same layout convention + args.stride_v = hdim_v; + args.nhead_stride_v = page_block_size * hdim_v; + + // descale: not used for fp16/bf16 + args.q_descale_ptr = nullptr; + args.k_descale_ptr = nullptr; + args.v_descale_ptr = nullptr; + args.nblock_stride_kv_block_descale = 0; + args.nhead_stride_kv_block_descale = 0; + } } }; @@ -1524,6 +1599,21 @@ fwd_result fmha_fwd_run(mode_enum mode, } auto run_fwd = [&](const ck_tile::stream_config& sc) { +#if CK_TILE_FMHA_FWD_BATCH_PREFILL_API + // batch_prefill: group mode + paged KV, tested against the same CPU reference + if(1 == num_splits && use_kvcache && mode == mode_enum::group) + { + fmha_batch_prefill_traits bp_traits; + init_traits(bp_traits); + + fmha_batch_prefill_args bp_args; + init_args(bp_args); + + const float ave_time = fmha_batch_prefill(bp_traits, bp_args, sc); + if(ave_time >= 0.0f) + return ave_time; + } +#endif // CK_TILE_FMHA_FWD_BATCH_PREFILL_API #if CK_TILE_FMHA_FWD_PAGEDKV_API if(1 == num_splits && use_kvcache) { @@ -1844,7 +1934,8 @@ fwd_result fmha_fwd_run(mode_enum mode, q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || \ + CK_TILE_FMHA_FWD_BATCH_PREFILL_API if(0 < page_block_size) { // clang-format off @@ -1895,7 +1986,8 @@ fwd_result fmha_fwd_run(mode_enum mode, }); } #endif -#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API +#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || \ + CK_TILE_FMHA_FWD_BATCH_PREFILL_API if(0 < page_block_size) { if(is_v_rowmajor) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp index c6628f66be..a523acd291 100644 --- 
a/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
+++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp
@@ -759,18 +759,19 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
             kargs.sink_ptr != nullptr
                 ? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
                 : -numeric<float>::infinity();
-        const index_t seqlen_k = [&]() {
+        // WA: structured bindings such as i_batch cannot be captured by lambdas before C++20
+        const index_t seqlen_k = [&, i_batch_ = i_batch]() {
             if constexpr(kKVLookupTable ==
                          BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
             {
-                const int32_t page_start = kargs.page_table.kv_indptr[i_batch];
-                const int32_t page_end   = kargs.page_table.kv_indptr[i_batch + 1];
+                const int32_t page_start = kargs.page_table.kv_indptr[i_batch_];
+                const int32_t page_end   = kargs.page_table.kv_indptr[i_batch_ + 1];
                 const int32_t num_page_blocks = page_end - page_start;
                 const int32_t last_page_len   = [&]() {
                     if constexpr(kPageBlockSize == 1)
                         return static_cast<int32_t>(kPageBlockSize);
                     else
-                        return kargs.page_table.kv_last_page_lens[i_batch];
+                        return kargs.page_table.kv_last_page_lens[i_batch_];
                 }();
                 return num_page_blocks > 0
                            ? static_cast<index_t>((num_page_blocks - 1) * kargs.page_block_size +
                                                   last_page_len)
                            : 0;
             }
             else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
             {
                 if(kargs.page_table.seqlen_k_ptr != nullptr)
-                    return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch]);
+                    return static_cast<index_t>(kargs.page_table.seqlen_k_ptr[i_batch_]);
                 else
                     return kargs.seqlen_k;
             }
         }();
-        const int32_t* page_idx = [&]() {
+        // WA: structured bindings such as i_batch cannot be captured by lambdas before C++20
+        const int32_t* page_idx = [&, i_batch_ = i_batch]() {
             if constexpr(kKVLookupTable ==
                          BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
             {
-                return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch];
+                return kargs.page_table.kv_page_indices + kargs.page_table.kv_indptr[i_batch_];
             }
             else // BlockAttentionKVCacheLookupTableEnum::VLLM_BLOCK_TABLE_2D
             {
                 return kargs.page_table.block_table_ptr +
-                       static_cast<long_index_t>(i_batch) *
+                       static_cast<long_index_t>(i_batch_) *
                            kargs.page_table.batch_stride_block_table;
             }
         }();
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
index a8b94b6e41..4f2d3d58c2 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp
@@ -291,6 +291,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
     static constexpr bool kHasDropout     = Problem::kHasDropout;
     static constexpr auto kKVMemoryLayout = Problem::kKVMemoryLayout;
     static constexpr auto QScaleEnum      = Problem::QScaleEnum;
+    static constexpr bool kHasSink        = Problem::kHasSink;
 
     // For KV_BLOCKSCALE: shift value for exp2(x + shift) to scale P to [0, 2^shift]
     // This avoids explicit P *= scale_p and v_descale /= scale_p operations
@@ -546,11 +547,25 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
         }
         __builtin_amdgcn_sched_barrier(0);
 
-        const auto q_origin = q_dram_window.get_window_origin();
-        const auto [seqlen_k_start, seqlen_k_end] =
-            mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
-
-        const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
+        const auto q_origin          = q_dram_window.get_window_origin();
+        const auto tile_range_result = [&mask, &q_origin]() {
+            if constexpr(kHasSink)
+                return
mask.GetSinkTileRangeAlongX(
+                    q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
+            else
+            {
+                auto [start, end] =
+                    mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
+                return ck_tile::make_tuple(0, start, end);
+            }
+        }();
+        const auto sink_seq_end   = tile_range_result.get(ck_tile::number<0>{});
+        const auto seqlen_k_start = tile_range_result.get(ck_tile::number<1>{});
+        const auto seqlen_k_end   = tile_range_result.get(ck_tile::number<2>{});
+        const auto num_sink_loop  = integer_divide_ceil(sink_seq_end, kN0);
+        const auto kv_load_start  = (sink_seq_end == 0 && seqlen_k_start > 0) ? seqlen_k_start : 0;
+        const auto num_total_loop =
+            integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0) + num_sink_loop;
 
         // check early exit if no work to do
         if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
@@ -576,7 +591,7 @@
         auto k_dram_block_window =
             make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                              k_dram_block_window_tmp.get_window_lengths(),
-                             {seqlen_k_start, 0});
+                             {kv_load_start, 0});
 
         auto k_dist  = Policy::template MakeKDramTileDistribution();
         auto k_coord = k_dist.calculate_index();
@@ -585,7 +600,7 @@
         // kPageBlockSize >= kN0: within-page offset only (SRD rebased per page via rebase_k_window)
         // kPageBlockSize < kN0: global offset, must fit int32
         statically_indexed_array k_offsets;
-        index_t current_seq_k = seqlen_k_start;
+        index_t current_seq_k = kv_load_start;
 
         // Load physical pages first, then compute offsets.
         // k_physical_pages can be reused for descale lookup later.
@@ -668,11 +683,11 @@
         auto bias_dram_window =
             make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                              bias_dram_block_window_tmp.get_window_lengths(),
-                             {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
+                             {bias_origin.at(number<0>{}), kv_load_start}, // M/N
                              Policy::template MakeBiasDramTileDistribution());
 
         auto randval_dram_window = dropout.template MakeRandvalDramWindow(
             randval_dram_block_window_tmp, seqlen_k_start);
 
         auto v_dist  = Policy::template MakeVDramTileDistribution();
         auto v_coord = v_dist.calculate_index();
@@ -895,7 +910,7 @@
         auto v_dram_window =
             make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
                                      v_dram_block_window_tmp.get_window_lengths(),
-                                     {0, seqlen_k_start}, // TODO: hdim split?
+                                     {0, kv_load_start}, // TODO: hdim split?
v_dist,
                                     v_offsets,
                                     number<1>{}, // HsGatherDim
@@ -1097,6 +1112,11 @@
 #endif
                 }
             }
+            if constexpr(kHasSink)
+            {
+                if(i_total_loops == num_sink_loop - 1)
+                    move_tile_window(bias_dram_window, {0, seqlen_k_start - sink_seq_end});
+            }
             move_tile_window(bias_dram_window, {0, kN0});
             if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
             {
@@ -1108,19 +1128,36 @@
 
                 if(need_perpixel_check)
                 {
-                    set_tile_if(
-                        s_acc, -numeric<SaccDataType>::infinity(), [&](auto tile_idx) {
-                            const auto row =
-                                q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
-                            const auto col =
-                                k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
-                            return !variant.LogitsMask(variant_params,
-                                                       block_indices.batch_idx,
-                                                       row,
-                                                       col,
-                                                       block_indices.qo_head_idx,
-                                                       block_indices.kv_head_idx);
+                    auto apply_mask = [&](auto&& mask_func) {
+                        set_tile_if(s_acc,
+                                    -numeric<SaccDataType>::infinity(),
+                                    [&](auto tile_idx) {
+                                        const auto row =
+                                            q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
+                                        const auto col =
+                                            k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
+                                        return !mask_func(variant_params,
+                                                          block_indices.batch_idx,
+                                                          row,
+                                                          col,
+                                                          block_indices.qo_head_idx,
+                                                          block_indices.kv_head_idx);
+                                    });
+                    };
+
+                    if constexpr(kHasSink)
+                    {
+                        apply_mask([&](auto&&... args) {
+                            return variant.LogitsSinkMask(
+                                std::forward<decltype(args)>(args)...);
                         });
+                    }
+                    else
+                    {
+                        apply_mask([&](auto&&... args) {
+                            return variant.LogitsMask(std::forward<decltype(args)>(args)...);
+                        });
+                    }
                 }
             }
 
@@ -1297,12 +1334,23 @@
             {
                 auto randval_ptr =
                     reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeKV();
+                index_t seq_offset = [&]() {
+                    if constexpr(kHasSink)
+                    {
+                        const bool in_sink_phase = (num_sink_loop > i_total_loops);
+                        if(i_total_loops == num_sink_loop)
+                            move_tile_window(randval_dram_window,
+                                             {0, seqlen_k_start - sink_seq_end});
+                        return in_sink_phase
+                                   ? (kv_load_start + i_total_loops * kN0)
+                                   : (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0);
+                    }
+                    else
+                        return seqlen_k_start + i_total_loops * kN0;
+                }();
                 dropout
                     .template Run(
-                        randval_ptr,
-                        seqlen_k_start + i_total_loops * kN0,
-                        p_compute,
-                        randval_dram_window);
+                        randval_ptr, seq_offset, p_compute, randval_dram_window);
             }
 
 #if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
@@ -1396,9 +1444,19 @@
             i_total_loops++;
             if(i_total_loops < num_total_loop)
             {
-                current_seq_k += kN0;
+                // For sink: after the last sink tile, jump K/V to seqlen_k_start;
+                // otherwise advance by one normal tile.
+                const index_t k_advance = [&]() -> index_t {
+                    if constexpr(kHasSink)
+                        return (i_total_loops == num_sink_loop)
+                                   ? (seqlen_k_start - sink_seq_end + kN0)
+                                   : kN0;
+                    else
+                        return kN0;
+                }();
+                current_seq_k += k_advance;
                 // move K tile windows
-                move_tile_window(k_dram_block_window, {kN0, 0});
+                move_tile_window(k_dram_block_window, {k_advance, 0});
                 k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
 
                 // KV_BLOCKSCALE: reload physical pages for the new tile
@@ -1427,6 +1485,21 @@
                     k_physical_pages, stride_k, page_stride_k, k_coord, k_offsets, current_seq_k);
                 k_dram_window.update_page_idx(k_offsets);
                 rebase_k_window(k_dram_window, k_physical_pages[number<0>{}]);
+
+                // After sink→window transition (i_total_loops == num_sink_loop), V window
+                // was advanced by kN0 (one normal iter), but current_seq_k jumped by k_advance
+                // = seqlen_k_start - sink_seq_end + kN0 > kN0. Re-init V to current_seq_k.
+                if constexpr(kHasSink)
+                {
+                    if(i_total_loops == num_sink_loop && num_sink_loop > 0)
+                    {
+                        prefetch_v_physical_pages(number<0>{});
+                        update_v_offsets(number<0>{});
+                        v_dram_window.update_page_idx(v_offsets);
+                        rebase_v_window(v_dram_window, v_physical_pages[number<0>{}]);
+                    }
+                }
+
                 if constexpr(k1_loops >= 2 &&
                              LdsSeq.at(number<0>{}) == LdsSeq.at(number{}))
                     __builtin_amdgcn_s_barrier();
diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
index 0670985e4f..7df39c3d11 100644
--- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
@@ -53,6 +53,7 @@ template
+          kHasSink_>
 {
     static constexpr auto kKVMemoryLayout = kKVMemoryLayout_;
     static constexpr auto kKVLookupTable  = kKVLookupTable_;

From 7bcaa73a3a261a367895a10d789eb4f9c1a7b26c Mon Sep 17 00:00:00 2001
From: arai713 <67439843+arai713@users.noreply.github.com>
Date: Tue, 21 Apr 2026 20:50:57 +0000
Subject: [PATCH 3/4] [rocm-libraries] ROCm/rocm-libraries#6537 (commit 16be4f7) [CK] Fix for hipblaslt error in PyTorch Dockerfile

## Motivation

This PR fixes the hipblaslt client build failures that occur when building the PyTorch Docker image, which are currently causing failures in CI.

## Technical Details

- Correctly reset the working directory to /tmp before cloning.
- Added `--use-system-packages` to the install.sh invocation to use the system-installed LAPACK packages, as the hard-coded paths were not being built.

## Test Plan

Locally built the Docker image using the Dockerfile.

## Test Result

The image was built successfully.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
---
 Dockerfile.pytorch | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/Dockerfile.pytorch b/Dockerfile.pytorch
index 2d3856fa2d..112197d207 100644
--- a/Dockerfile.pytorch
+++ b/Dockerfile.pytorch
@@ -22,6 +22,7 @@ RUN groupadd -g 109 render && \
     chmod -R a+rwx /tmp/pytorch && \
     sudo usermod -aG irc jenkins && \
     #install hipblaslt
+    cd /tmp && \
     git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \
     cd rocm-libraries && \
     git checkout develop && \
@@ -29,4 +30,4 @@ RUN groupadd -g 109 render && \
     git sparse-checkout set projects/hipblaslt shared/origami && \
     cd projects/hipblaslt && \
     git show --oneline -s && \
-    CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller
+    CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --use-system-packages --architecture="gfx942;gfx950" -j 128 --skip_rocroller

From 9d34174ac2473689d3a80bd8f1c2df86f6b4ff35 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= <38502616+bartekxk@users.noreply.github.com>
Date: Tue, 21 Apr 2026 21:50:07 +0000
Subject: [PATCH 4/4] [rocm-libraries] ROCm/rocm-libraries#5646 (commit 05680a4) [CK_TILE] Add conv bwd data tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

This PR adds tests for CK Tile's convolution backward data operation to enable functionality regression tracking and error detection.

## Technical Details

Currently only the NHWGC/GKYXC/NHWGK and NDHWGC/GKZYXC/NDHWGK (2D and 3D channel-last) layouts are tested, since only these are implemented in CK Tile.
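To make the layout naming concrete: in channel-last NHWGC the dimension order in memory is batch, spatial, group, channel, so a packed tensor's strides are products of the trailing extents. The sketch below is an illustration only, not CK Tile builder code, and its extents are made up.

```cpp
#include <array>
#include <cstdio>

int main()
{
    // N, H, W, G, C extents of a packed NHWGC activation tensor
    const std::array<long, 5> ext = {2, 28, 28, 2, 16};

    // packed strides: innermost (C) has stride 1, each outer dimension's
    // stride is the product of all extents to its right
    std::array<long, 5> stride{};
    long s = 1;
    for(int d = 4; d >= 0; --d)
    {
        stride[d] = s;
        s *= ext[d];
    }

    // linear offset of element (n=1, h=3, w=4, g=1, c=5)
    const std::array<long, 5> idx = {1, 3, 4, 1, 5};
    long off = 0;
    for(int d = 0; d < 5; ++d)
        off += idx[d] * stride[d];
    std::printf("offset = %ld of %ld elements\n", off, s);
    return 0;
}
```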
The tests cover the FP16, BF16 and FP32 data types and a variety of convolution scenarios. The tested instances are listed in the `experimental/grouped_convolution_tile_instances` directory.

## Test Result

All implemented tests work properly and pass.
---
 .../ck_tile/builder/testing/conv/ck_tile.hpp  |  13 +-
 .../generate_instances.py                     |  14 +-
 test/grouped_convnd_bwd_data/CMakeLists.txt   |  11 +
 .../test_grouped_convnd_bwd_data_tile.cpp     | 258 ++++++++++++++++++
 4 files changed, 280 insertions(+), 16 deletions(-)
 create mode 100644 test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_tile.cpp

diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp
index 914c988d09..6eece48831 100644
--- a/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp
+++ b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp
@@ -118,14 +118,11 @@ template
 )
 {
-        if(kargs.k_batch > 1)
-        {
-            ck_tile::hip_check_error(
-                hipMemsetAsync(kargs.in_ptr,
-                               0,
-                               zeroing_size * sizeof(typename Types::EDataType),
-                               s_conf.stream_id_));
-        }
+        ck_tile::hip_check_error(
+            hipMemsetAsync(kargs.in_ptr,
+                           0,
+                           zeroing_size * sizeof(typename Types::EDataType),
+                           s_conf.stream_id_));
     }
 };

diff --git a/experimental/grouped_convolution_tile_instances/generate_instances.py b/experimental/grouped_convolution_tile_instances/generate_instances.py
index 796e6b9158..b60e92c728 100755
--- a/experimental/grouped_convolution_tile_instances/generate_instances.py
+++ b/experimental/grouped_convolution_tile_instances/generate_instances.py
@@ -566,14 +566,12 @@ def parse_bwd_data_instances(instances, problem_name):
         if pipeline_version == "V6":
             print(f"Skipping instance {instance_id} with V6 since it's not supported yet.")
             continue
-
-        # Check vector sizes for A and B tensors - we cannot oversubscribe.
-        num_tile_elements_a = m_per_xdl * k_per_xdl
-        num_tile_elements_b = n_per_xdl * k_per_xdl
-        max_vector_size_a = max(1, num_tile_elements_a // block_size)
-        max_vector_size_b = max(1, num_tile_elements_b // block_size)
-        a_scalar_per_vector = min(a_scalar_per_vector, max_vector_size_a)
-        b_scalar_per_vector = min(b_scalar_per_vector, max_vector_size_b)
+        if k_per_block > (warp_size * a_scalar_per_vector) or n_per_block > (warp_size * b_scalar_per_vector):
+            print(f"Skipping instance {instance_id} with multiple warps per continuous tile dim since it's not supported yet.")
+            continue
+        if a_scalar_per_vector > (m_per_block * k_per_block) // block_size or b_scalar_per_vector > (n_per_block * k_per_block) // block_size:
+            print(f"Skipping instance {instance_id} because current scalar per vector exceeds tile size")
+            continue
 
         conv = ConvInstanceTemplateParams(
             spec,
diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt
index 514f8e9668..7a318b4c19 100644
--- a/test/grouped_convnd_bwd_data/CMakeLists.txt
+++ b/test/grouped_convnd_bwd_data/CMakeLists.txt
@@ -22,6 +22,17 @@ if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
     target_link_libraries(test_grouped_conv_bwd_data_scale PRIVATE gtest_main getopt::getopt utility device_grouped_conv3d_bwd_data_scale_instance)
 endif()
 
+if(GPU_TARGETS MATCHES "gfx9")
+    if(CK_EXPERIMENTAL_BUILDER)
+        add_gtest_executable(test_grouped_convnd_bwd_data_tile test_grouped_convnd_bwd_data_tile.cpp)
+        target_compile_options(test_grouped_convnd_bwd_data_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat)
+        target_link_libraries(test_grouped_convnd_bwd_data_tile PRIVATE gtest_main getopt::getopt utility)
+        if(TARGET device_grouped_conv_bwd_data_tile_instances)
+            target_link_libraries(test_grouped_convnd_bwd_data_tile PRIVATE device_grouped_conv_bwd_data_tile_instances)
+        endif()
+    endif()
+endif()
+
 if (CK_USE_XDL OR CK_USE_WMMA)
     add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp)
     if(result EQUAL 0)
diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_tile.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_tile.cpp
new file mode 100644
index 0000000000..0b1c6e55f7
--- /dev/null
+++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_tile.cpp
@@ -0,0 +1,258 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include <gtest/gtest.h>
+
+#include <cstdlib>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck_tile/builder/testing/conv/ck_tile.hpp"
+#include "ck_tile/host/device_prop.hpp"
+#include "profiler/grouped_convolution_backward_data_tile_algs.hpp"
+
+static ck::index_t args_mask      = 0xffff;
+static ck::index_t instance_index = -1;
+
+namespace ckb = ck_tile::builder;
+namespace ckt = ck_tile::builder::test;
+namespace ckp = ck_tile::builder::profiling;
+
+template <ck_tile::index_t num_spatial_dim_,
+          ckb::DataType data_type_,
+          ckb::DataType acc_data_type_,
+          ckb::TensorLayout in_layout_,
+          ckb::TensorLayout wei_layout_,
+          ckb::TensorLayout out_layout_>
+struct SignatureDetails
+{
+    static constexpr ck_tile::index_t num_spatial_dim = num_spatial_dim_;
+    static constexpr ckb::DataType data_type          = data_type_;
+    static constexpr ckb::DataType acc_data_type      = acc_data_type_;
+    static constexpr ckb::TensorLayout in_layout      = in_layout_;
+    static constexpr ckb::TensorLayout wei_layout     = wei_layout_;
+    static constexpr ckb::TensorLayout out_layout     = out_layout_;
+};
+
+template <typename SignatureDetailsType>
+class TestGroupedConvndBwdDataTile : public ::testing::Test
+{
+    protected:
+    static constexpr auto SIGNATURE =
+        ckt::ConvSignature{.spatial_dim = SignatureDetailsType::num_spatial_dim,
+                           .direction   = ckb::ConvDirection::BACKWARD_DATA,
+                           .data_type   = SignatureDetailsType::data_type,
+                           .accumulation_data_type = SignatureDetailsType::acc_data_type,
+                           .input  = {.config = {.layout = SignatureDetailsType::in_layout}},
+                           .weight = {.config = {.layout = SignatureDetailsType::wei_layout}},
+                           .output = {.config = {.layout = SignatureDetailsType::out_layout}}};
+
+    std::vector<ckt::Args> conv_args;
+    std::vector<std::string> split_ks{"1", "2"};
+
+    template <ck_tile::index_t num_spatial_dim>
+    void Run()
+    {
+        ASSERT_FALSE(conv_args.empty());
+        bool pass = true;
+        for(size_t i = 0; i < conv_args.size(); i++)
+        {
+            for(auto& split_k : split_ks)
+            {
+                if((args_mask & (1 << i)) == 0)
+                {
+                    continue;
+                }
+                auto& args = conv_args[i];
+
+                auto inputs  = alloc_inputs(args);
+                auto outputs = alloc_outputs(args);
+                ckt::init_tensor_buffer_uniform_int(
+                    inputs.get().weight, args.make_weight_descriptor(), -5, 5);
+                ckt::init_tensor_buffer_uniform_int(
+                    inputs.get().output, args.make_output_descriptor(), -5, 5);
+
+                HIP_CHECK_ERROR(
+                    hipMemset(outputs.get().input,
+                              0,
+                              args.make_input_descriptor().get_element_space_size_in_bytes()));
+
+                std::cout << args.make_input_descriptor() << std::endl;
+                std::cout << args.make_weight_descriptor() << std::endl;
+                std::cout << args.make_output_descriptor() << std::endl;
+                [[maybe_unused]] auto&& [case_passed,
+                                         avg_time,
+                                         op_name,
+                                         best_split_k,
+                                         best_instance] =
+
+                    ckp::run_grouped_conv_backward_data_tile_algs(
+                        args,
+                        split_k,
+                        -1,
+                        inputs.get(),
+                        outputs.get(),
+                        ck_tile::stream_config{nullptr, false /*time_kernel*/});
+
+                pass = pass && case_passed;
+            }
+        }
+        EXPECT_TRUE(pass);
+    }
+
+    void conv_args_append(std::size_t,
+                          std::size_t G,
+                          std::size_t N,
+                          std::size_t K,
+                          std::size_t C,
+                          const std::vector<std::size_t>& filter_spatial_lengths,
+                          const std::vector<std::size_t>& input_spatial_lengths,
+                          const std::vector<std::size_t>& conv_filter_strides,
+                          const std::vector<std::size_t>& conv_filter_dilations,
+                          const std::vector<std::size_t>& input_left_pads,
+                          const std::vector<std::size_t>& input_right_pads)
+    {
+        ckt::Args args = {
+            .lengths =
+                {
+                    .batch_size      = N,
+                    .groups          = G,
+                    .input_channels  = C,
+                    .output_channels = K,
+                    .image = ckt::filter_extent_from_vector(
+                        input_spatial_lengths),
+                    .filter = ckt::filter_extent_from_vector(
+                        filter_spatial_lengths),
+                },
+            .filter_strides = ckt::filter_extent_from_vector(
+                conv_filter_strides),
+            .filter_dilation =
+                ckt::filter_extent_from_vector(
+                    conv_filter_dilations),
+            .input_left_pad = ckt::filter_extent_from_vector(
+                input_left_pads),
+            .input_right_pad =
+                ckt::filter_extent_from_vector(
+                    input_right_pads),
+            .a_elementwise_op   = {},
+            .b_elementwise_op   = {},
+            .cde_elementwise_op = {},
+        };
+        conv_args.push_back(args);
+    }
+};
+
+using KernelTypes2d = ::testing::Types<SignatureDetails<2,
+                                                        ckb::DataType::FP32,
+                                                        ckb::DataType::FP32,
+                                                        ckb::TensorLayout::NHWGC,
+                                                        ckb::TensorLayout::GKYXC,
+                                                        ckb::TensorLayout::NHWGK>,
+                                       SignatureDetails<2,
+                                                        ckb::DataType::FP16,
+                                                        ckb::DataType::FP32,
+                                                        ckb::TensorLayout::NHWGC,
+                                                        ckb::TensorLayout::GKYXC,
+                                                        ckb::TensorLayout::NHWGK>,
+                                       SignatureDetails<2,
+                                                        ckb::DataType::BF16,
+                                                        ckb::DataType::FP32,
+                                                        ckb::TensorLayout::NHWGC,
+                                                        ckb::TensorLayout::GKYXC,
+                                                        ckb::TensorLayout::NHWGK>>;
+
+using KernelTypes3d = ::testing::Types<SignatureDetails<3,
+                                                        ckb::DataType::FP32,
+                                                        ckb::DataType::FP32,
+                                                        ckb::TensorLayout::NDHWGC,
+                                                        ckb::TensorLayout::GKZYXC,
+                                                        ckb::TensorLayout::NDHWGK>,
+                                       SignatureDetails<3,
+                                                        ckb::DataType::FP16,
+                                                        ckb::DataType::FP32,
+                                                        ckb::TensorLayout::NDHWGC,
+                                                        ckb::TensorLayout::GKZYXC,
+                                                        ckb::TensorLayout::NDHWGK>,
+                                       SignatureDetails<3,
+                                                        ckb::DataType::BF16,
+                                                        ckb::DataType::FP32,
+                                                        ckb::TensorLayout::NDHWGC,
+                                                        ckb::TensorLayout::GKZYXC,
+                                                        ckb::TensorLayout::NDHWGK>>;
+
+template <typename T>
+class TestGroupedConvndBwdDataTile2d : public TestGroupedConvndBwdDataTile<T>
+{
+};
+
+template <typename T>
+class TestGroupedConvndBwdDataTile3d : public TestGroupedConvndBwdDataTile<T>
+{
+};
+
+TYPED_TEST_SUITE(TestGroupedConvndBwdDataTile2d, KernelTypes2d);
+TYPED_TEST_SUITE(TestGroupedConvndBwdDataTile3d, KernelTypes3d);
+
+TYPED_TEST(TestGroupedConvndBwdDataTile2d, Test2D)
+{
+    this->conv_args.clear();
+
+    // GroupedGemmGroupsNum = 4, ZTilde * YTilde * XTilde = 4, MaxGroupedGemmGroupsNum = 32
+    this->conv_args_append(2, 2, 2, 16, 16, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1});
+    // GroupedGemmGroupsNum = 9, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
+    this->conv_args_append(2, 2, 2, 16, 16, {3, 3}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1});
+    // GroupedGemmGroupsNum = 36, ZTilde * YTilde * XTilde = 36, MaxGroupedGemmGroupsNum = 32
+    this->conv_args_append(2, 2, 2, 16, 16, {6, 6}, {28, 28}, {6, 6}, {1, 1}, {1, 1}, {1, 1});
+    // GroupedGemmGroupsNum = 32, ZTilde * YTilde * XTilde = 32, MaxGroupedGemmGroupsNum = 32
+    this->conv_args_append(2, 2, 2, 16, 16, {4, 8}, {28, 28}, {4, 8}, {1, 1}, {1, 1}, {1, 1});
+    this->conv_args_append(2, 2, 2, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
+    this->conv_args_append(2, 2, 2, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
+    this->conv_args_append(2, 2, 2, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0});
+    this->conv_args_append(2, 2, 2, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0});
+    this->conv_args_append(2, 2, 2, 32, 32, {2, 2}, {12, 12}, {3, 3}, {1, 1}, {0, 0}, {0, 0});
+    this->conv_args_append(2, 2, 2, 32, 32, {2, 2}, {12, 12}, {2, 2}, {2, 2}, {0, 0}, {0, 0});
+    this->conv_args_append(2, 1, 6, 448, 896, {1, 1}, {118, 182}, {2, 2}, {1, 1}, {0, 0}, {0, 0});
+    this->conv_args_append(2, 1, 1, 1, 32, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
+    this->conv_args_append(2, 1, 1, 64, 3, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
+    this->conv_args_append(2, 1, 1, 1, 1, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1});
+    this->template Run<2>();
+}
+
+TYPED_TEST(TestGroupedConvndBwdDataTile3d, Test3D)
+{
+    this->conv_args.clear();
+    this->conv_args_append(
+        3, 2, 2, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0});
+    this->conv_args_append(
+        3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1});
+    this->conv_args_append(
+        3, 2, 2, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0});
+    this->conv_args_append(
+        3, 2, 2, 32, 32, {1, 2, 2}, {1, 12, 12}, {1, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0});
+ this->conv_args_append( + 3, 2, 2, 32, 32, {1, 2, 2}, {1, 12, 12}, {1, 2, 2}, {1, 2, 2}, {0, 0, 0}, {0, 0, 0}); + this->conv_args_append( + 3, 1, 1, 1, 32, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 1, 1, 64, 3, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->conv_args_append( + 3, 1, 1, 1, 1, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + this->template Run<3>(); +} + +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + if(argc == 1) {} + else if(argc == 3) + { + args_mask = strtol(argv[1], nullptr, 0); + instance_index = atoi(argv[2]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1,2: args_mask instance_index(-1 means all)" << std::endl; + } + return RUN_ALL_TESTS(); +}
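As a usage note on the driver above: `args_mask` is a bitmask selecting which of the appended convolution cases run, and `instance_index` selects a single instance (-1 means all). The filter reduces to the standalone sketch below; the mask value and case count are illustrative, not part of the test.

```cpp
#include <cstdio>

// Bit i of args_mask selects conv_args[i]. For example, a mask of 0x5
// keeps cases 0 and 2 and skips everything else; the default 0xffff
// keeps every case.
int main()
{
    const int args_mask = 0x5;
    const int num_cases = 4;
    for(int i = 0; i < num_cases; ++i)
    {
        if((args_mask & (1 << i)) == 0)
            continue;
        std::printf("running conv_args[%d]\n", i);
    }
    return 0;
}
```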