From d1d457b82a63fdf6e68461194fa2c1098ace5f93 Mon Sep 17 00:00:00 2001
From: Gino Lu
Date: Mon, 13 Apr 2026 03:34:08 -0400
Subject: [PATCH] Add Sparge block map GPU pipeline to tile_example_sparge_vsa_sparse_attn

---
 example/ck_tile/50_sparse_attn/CMakeLists.txt |  34 +-
 .../50_sparse_attn/sparge_blockmap.cpp        | 156 ++++
 .../ck_tile/50_sparse_attn/sparge_blockmap.h  |  26 +
 .../50_sparse_attn/sparge_blockmap_inst.cpp   |  88 +++
 .../50_sparse_attn/sparge_blockmap_trek.hpp   |  93 ++++
 .../test_sparge_vsa_sparse_attn.cpp           | 232 ++++--
 .../kernel/sparge_blockmap_kernel.hpp         | 195 +++++
 .../pipeline/sparge_blockmap_pipeline.hpp     | 521 ++++++++++++++++++
 8 files changed, 1295 insertions(+), 50 deletions(-)
 create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap.cpp
 create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap.h
 create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp
 create mode 100644 example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp
 create mode 100644 include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp
 create mode 100644 include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp

diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt
index 0ac86f6aff..169ed87ac3 100644
--- a/example/ck_tile/50_sparse_attn/CMakeLists.txt
+++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt
@@ -266,11 +266,41 @@ target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE
     -Wno-float-equal
 )
 
-# Sparge + VSA Example executable
+# ============================================================================
+# Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen)
+# ============================================================================
+set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances")
+
+add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL
+    ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp
+)
+target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
+    ${CMAKE_CURRENT_LIST_DIR}
+    ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
+)
+set_source_files_properties(
+    ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp
+    PROPERTIES LANGUAGE HIP
+)
+set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
+
+target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
+    -DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
+    -DCK_TILE_FMHA_FWD_FAST_EXP2
+    -Wno-undefined-func-template
+    -Wno-float-equal
+)
+
+# Sparge + VSA Example executable (now links blockmap kernel too)
 set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn")
 message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}")
 add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp)
-target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARGE_VSA_INSTANCES})
+target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}
+    ${SPARGE_VSA_INSTANCES}
+    ${SPARGE_BLOCKMAP_INSTANCES}
+)
 target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
 target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE
     -Wno-undefined-func-template
diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp
new file mode 100644
index 0000000000..b9ac56c533
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp
@@ -0,0 +1,156 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#include "sparge_blockmap.h"
+#include "sparge_blockmap_trek.hpp"
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/host_tensor.hpp"
+#include "ck_tile/host/device_memory.hpp"
+#include <cmath>
+#include <string>
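+
+// Output format, illustrated on one (batch, head, q_block) row: the kernel
+// writes a one-hot block map (block_map_out[kb] = 1 iff K block kb is kept)
+// plus a delta-encoded LUT. If, say, blocks {1, 2, 5} survive out of 8, the
+// LUT row becomes [1, 1, 3, 0, 0, 0, 0, 0] (each entry is kb minus the
+// previously kept kb) and valid_block_num = 3. The indices here are made-up
+// example values, not defaults from the test.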
+
+template <typename DataTypeQK>
+sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor<DataTypeQK>& TQ,
+                                   const ck_tile::HostTensor<DataTypeQK>& TK,
+                                   ck_tile::HostTensor<uint8_t>& block_map_out,
+                                   int batch,
+                                   int nhead_q,
+                                   int nhead_k,
+                                   int seqlen_q,
+                                   int seqlen_k,
+                                   int hdim_q,
+                                   bool i_perm,
+                                   float simthreshd1,
+                                   float cdfthreshd,
+                                   float topk,
+                                   int blkq,
+                                   int blkk,
+                                   int log_level)
+{
+    static_assert(std::is_same_v<DataTypeQK, ck_tile::half_t> ||
+                      std::is_same_v<DataTypeQK, ck_tile::bf16_t>,
+                  "sparge_blockmap_gpu supports fp16/bf16 only.");
+
+    std::string data_type = "fp16";
+    if constexpr(std::is_same_v<DataTypeQK, ck_tile::bf16_t>)
+    {
+        data_type = "bf16";
+    }
+
+    const ck_tile::index_t num_q_blocks = ck_tile::integer_divide_ceil(seqlen_q, blkq);
+    const ck_tile::index_t num_k_blocks = ck_tile::integer_divide_ceil(seqlen_k, blkk);
+
+    const float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
+
+    // Allocate device memory
+    ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
+
+    const std::size_t bmap_bytes =
+        static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(uint8_t);
+    const std::size_t lut_bytes =
+        static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(int32_t);
+    const std::size_t valid_bytes =
+        static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * sizeof(int32_t);
+
+    ck_tile::DeviceMem bmap_buf(bmap_bytes);
+    ck_tile::DeviceMem lut_buf(lut_bytes);
+    ck_tile::DeviceMem valid_buf(valid_bytes);
+
+    q_buf.ToDevice(TQ.data());
+    k_buf.ToDevice(TK.data());
+    bmap_buf.SetZero();
+    lut_buf.SetZero();
+    valid_buf.SetZero();
+
+    // Compute strides (assumes BHSD if i_perm, BSHD otherwise)
+    const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead_q * hdim_q;
+    const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q;
+    const ck_tile::index_t nhead_stride_q =
+        i_perm ? static_cast<ck_tile::index_t>(seqlen_q) * hdim_q : hdim_q;
+    const ck_tile::index_t nhead_stride_k =
+        i_perm ? static_cast<ck_tile::index_t>(seqlen_k) * hdim_q : hdim_q;
+    const ck_tile::index_t batch_stride_q =
+        static_cast<ck_tile::index_t>(nhead_q) * seqlen_q * hdim_q;
+    const ck_tile::index_t batch_stride_k =
+        static_cast<ck_tile::index_t>(nhead_k) * seqlen_k * hdim_q;
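+    // For example (illustrative numbers only): with nhead_q = 2, seqlen_q = 8,
+    // hdim_q = 4, the BHSD layout (i_perm = true) gives stride_q = 4 and
+    // nhead_stride_q = 8 * 4 = 32, while BSHD (i_perm = false) gives
+    // stride_q = 2 * 4 = 8 and nhead_stride_q = 4. batch_stride_q spans a
+    // whole head-by-sequence slab (2 * 8 * 4 = 64 elements) in either layout.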
+
+    ck_tile::stream_config stream_config{nullptr, false, log_level, 0, 1, false};
+
+    sparge_blockmap_args args;
+    args.q_ptr = q_buf.GetDeviceBuffer();
+    args.k_ptr = k_buf.GetDeviceBuffer();
+    args.batch = batch;
+    args.seqlen_q = seqlen_q;
+    args.seqlen_k = seqlen_k;
+    args.hdim_q = hdim_q;
+    args.nhead_q = nhead_q;
+    args.nhead_k = nhead_k;
+    args.stride_q = stride_q;
+    args.stride_k = stride_k;
+    args.nhead_stride_q = nhead_stride_q;
+    args.nhead_stride_k = nhead_stride_k;
+    args.batch_stride_q = batch_stride_q;
+    args.batch_stride_k = batch_stride_k;
+    args.simthreshd1 = simthreshd1;
+    args.cdfthreshd = cdfthreshd;
+    args.topk = topk;
+    args.scale = scale;
+    args.block_map_ptr = bmap_buf.GetDeviceBuffer();
+    args.lut_ptr = lut_buf.GetDeviceBuffer();
+    args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
+
+    sparge_blockmap_traits traits;
+    traits.data_type = data_type;
+    traits.hdim_q = hdim_q;
+
+    sparge_blockmap_fwd(traits, args, stream_config);
+
+    // Copy results back to host
+    bmap_buf.FromDevice(block_map_out.data(), bmap_bytes);
+
+    sparge::VSALut vsa_lut{
+        ck_tile::HostTensor<int32_t>({batch, nhead_q, num_q_blocks, num_k_blocks}),
+        ck_tile::HostTensor<int32_t>({batch, nhead_q, num_q_blocks}),
+    };
+    lut_buf.FromDevice(vsa_lut.lut.data(), lut_bytes);
+    valid_buf.FromDevice(vsa_lut.valid_block_num.data(), valid_bytes);
+
+    return vsa_lut;
+}
+
+// Explicit template instantiations
+template sparge::VSALut
+sparge_blockmap_gpu<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
+                                     const ck_tile::HostTensor<ck_tile::half_t>&,
+                                     ck_tile::HostTensor<uint8_t>&,
+                                     int,
+                                     int,
+                                     int,
+                                     int,
+                                     int,
+                                     int,
+                                     bool,
+                                     float,
+                                     float,
+                                     float,
+                                     int,
+                                     int,
+                                     int);
+
+template sparge::VSALut
+sparge_blockmap_gpu<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
+                                     const ck_tile::HostTensor<ck_tile::bf16_t>&,
+                                     ck_tile::HostTensor<uint8_t>&,
+                                     int,
+                                     int,
+                                     int,
+                                     int,
+                                     int,
+                                     int,
+                                     bool,
+                                     float,
+                                     float,
+                                     float,
+                                     int,
+                                     int,
+                                     int);
diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.h b/example/ck_tile/50_sparse_attn/sparge_blockmap.h
new file mode 100644
index 0000000000..3057257ca1
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/sparge_blockmap.h
@@ -0,0 +1,26 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include <cstdint>
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/host_tensor.hpp"
+#include "sparge_tool.hpp"
+
+template <typename DataTypeQK>
+sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor<DataTypeQK>& TQ,
+                                   const ck_tile::HostTensor<DataTypeQK>& TK,
+                                   ck_tile::HostTensor<uint8_t>& block_map_out,
+                                   int batch,
+                                   int nhead_q,
+                                   int nhead_k,
+                                   int seqlen_q,
+                                   int seqlen_k,
+                                   int hdim_q,
+                                   bool i_perm,
+                                   float simthreshd1,
+                                   float cdfthreshd,
+                                   float topk,
+                                   int blkq,
+                                   int blkk,
+                                   int log_level = 0);
diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp
new file mode 100644
index 0000000000..fbd18b9ff2
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp
@@ -0,0 +1,88 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+// Hand-written template instantiation for SpargeBlockMapKernel (fp16, D=128).
+
+#include "sparge_blockmap_trek.hpp"
+#include "ck_tile/ops/fmha/block/variants.hpp"
+
+#include <iostream>
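+
+// Note: this instantiation hard-wires the Sparge block granularity to
+// kM0 = 64 query rows by kN0 = 128 key rows. The blkq/blkk values supplied on
+// the host side (see sparge_blockmap_gpu) are only used to size the output
+// tensors, so they must match kM0/kN0 — i.e. blkq = 64, blkk = 128 here —
+// or the host and device block grids will disagree.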
+
+// ============================================================================
+// Type configuration for block map kernel (reuses FmhaSparseFwdTypeConfig)
+// ============================================================================
+
+// fp16: D=128, kM0=64, kN0=128
+using bmap_fp16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>;
+//                                             kM0 kN0  kK0  kN1  kK1  kQKHeaddim(D)
+
+using bmap_fp16_config = FmhaSparseFwdTypeConfig<ck_tile::half_t>;
+
+using bmap_fp16_shape =
+    ck_tile::TileFmhaShape<bmap_fp16_block_tile,
+                           ck_tile::sequence<4, 1, 1>,    // Gemm0BlockWarps
+                           ck_tile::sequence<16, 16, 16>, // Gemm0WarpTile (unused by blockmap, but
+                                                          // needed by shape)
+                           ck_tile::sequence<4, 1, 1>,    // Gemm1BlockWarps
+                           ck_tile::sequence<16, 16, 16>, // Gemm1WarpTile
+                           true>;                         // VLayout row-major
+
+using bmap_fp16_trait = ck_tile::TileFmhaTraits<true,  // kPadSeqLenQ
+                                                true,  // kPadSeqLenK
+                                                true,  // kPadHeadDimQ
+                                                true,  // kPadHeadDimV
+                                                false, // kHasLogitsSoftCap
+                                                ck_tile::BlockAttentionBiasEnum::NO_BIAS,
+                                                false,  // kHasBiasGrad
+                                                false,  // kStoreLSE
+                                                false,  // kDoFp8StaticQuant
+                                                1,      // occupancy
+                                                false>; // kIsVRowMajorSkip
+
+using bmap_fp16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>;
+using bmap_fp16_mask    = ck_tile::GenericAttentionMask<false>;
+
+using bmap_fp16_problem =
+    ck_tile::BlockFmhaPipelineProblem<typename bmap_fp16_config::QDataType,
+                                      typename bmap_fp16_config::KDataType,
+                                      typename bmap_fp16_config::VDataType,
+                                      typename bmap_fp16_config::SaccDataType,
+                                      typename bmap_fp16_config::SMPLComputeDataType,
+                                      typename bmap_fp16_config::BiasDataType,
+                                      typename bmap_fp16_config::RandValOutputDataType,
+                                      typename bmap_fp16_config::LSEDataType,
+                                      typename bmap_fp16_config::PDataType,
+                                      typename bmap_fp16_config::OaccDataType,
+                                      typename bmap_fp16_config::ODataType,
+                                      bmap_fp16_shape,
+                                      false, // kIsGroupMode
+                                      bmap_fp16_variant,
+                                      bmap_fp16_mask,
+                                      bmap_fp16_trait>;
+
+using bmap_fp16_pipeline = ck_tile::SpargeBlockMapPipeline<bmap_fp16_problem>;
+using bmap_fp16_kernel   = ck_tile::SpargeBlockMapKernel<bmap_fp16_pipeline>;
+
+// ============================================================================
+// Dispatch
+// ============================================================================
+
+float sparge_blockmap_fwd(sparge_blockmap_traits traits,
+                          sparge_blockmap_args args,
+                          const ck_tile::stream_config& s)
+{
+    if(traits.data_type == "fp16" && traits.hdim_q == 128)
+    {
+        using k_ = bmap_fp16_kernel;
+        if(s.log_level_ > 0)
+            std::cout << ", sparge_blockmap_fp16_d128" << std::flush;
+        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
+        constexpr dim3 blocks                  = k_::BlockSize();
+        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
+        return ck_tile::launch_kernel(
+            s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
+    }
+
+    if(s.log_level_ > 0)
+        std::cerr << "sparge_blockmap_fwd: unsupported config (data_type=" << traits.data_type
+                  << ", hdim_q=" << traits.hdim_q << ")" << std::endl;
+    return -1.f;
+}
diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp
new file mode 100644
index 0000000000..1e7e33248a
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp
@@ -0,0 +1,93 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
+#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
+#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp"
+#include "ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp"
+
+#include "fmha_fwd_trek.hpp"
+
+#include <cassert>
+#include <string>
+
+// ============================================================================
+// Args and traits for sparge block map GPU kernel
+// ============================================================================
+struct sparge_blockmap_args
+{
+    const void* q_ptr;
+    const void* k_ptr;
+
+    ck_tile::index_t batch;
+    ck_tile::index_t seqlen_q;
+    ck_tile::index_t seqlen_k;
+    ck_tile::index_t hdim_q;
+    ck_tile::index_t nhead_q;
+    ck_tile::index_t nhead_k;
+
+    ck_tile::index_t stride_q;
+    ck_tile::index_t stride_k;
+    ck_tile::index_t nhead_stride_q;
+    ck_tile::index_t nhead_stride_k;
+    ck_tile::index_t batch_stride_q;
+    ck_tile::index_t batch_stride_k;
+
+    float simthreshd1;
+    float cdfthreshd;
+    float topk;
+    float scale;
+
+    void* block_map_ptr;
+    void* lut_ptr;
+    void* valid_block_num_ptr;
+};
+
+struct sparge_blockmap_traits
+{
+    std::string data_type;
+    int hdim_q;
+};
+
+// ============================================================================
+// Create kernel args and grid dimensions
+// ============================================================================
+template <typename BlockMapKernel>
+auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args)
+{
+    assert(args.nhead_q % args.nhead_k == 0);
+    auto kargs = BlockMapKernel::MakeKargs(args.q_ptr,
+                                           args.k_ptr,
+                                           args.seqlen_q,
+                                           args.seqlen_k,
+                                           args.hdim_q,
+                                           args.nhead_q,
+                                           args.nhead_q / args.nhead_k,
+                                           args.stride_q,
+                                           args.stride_k,
+                                           args.nhead_stride_q,
+                                           args.nhead_stride_k,
+                                           args.batch_stride_q,
+                                           args.batch_stride_k,
+                                           args.simthreshd1,
+                                           args.cdfthreshd,
+                                           args.topk,
+                                           args.scale,
+                                           args.block_map_ptr,
+                                           args.lut_ptr,
+                                           args.valid_block_num_ptr);
+
+    dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q);
+    return ck_tile::make_tuple(kargs, grids);
+}
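+
+// nhead_ratio_qk encodes grouped-query attention: inside the kernel each
+// Q head hq reads K head hq / nhead_ratio_qk. For instance, nhead_q = 8 with
+// nhead_k = 2 gives a ratio of 4, so Q heads 0-3 share K head 0 and Q heads
+// 4-7 share K head 1 (example head counts, not defaults from the test).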
+
+// ============================================================================
+// Hand-written template instantiation dispatch
+// ============================================================================
+float sparge_blockmap_fwd(sparge_blockmap_traits traits,
+                          sparge_blockmap_args args,
+                          const ck_tile::stream_config& stream_config);
diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp
index c0feb23e58..638a867b0f 100644
--- a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp
+++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp
@@ -17,6 +17,7 @@
 #include "ck_tile/core/utility/bit_cast.hpp"
 
 #include "vsa_sparge_attention.h"
+#include "sparge_blockmap.h"
 #include "sparge_tool.hpp"
 
 // ============================================================================
@@ -198,53 +199,37 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
     ck_tile::HostTensor<DataType> output_host =
         o_perm ? ck_tile::HostTensor<DataType>({batch, nhead, seqlen_q, hdim_v})
                : ck_tile::HostTensor<DataType>({batch, seqlen_q, nhead, hdim_v});
-    ck_tile::HostTensor<DataType> output_ref({batch, nhead, seqlen_q, hdim_v});
 
     std::cout << "\nInitializing tensors..." << std::endl;
     ck_tile::FillUniformDistribution<DataType>{-0.5f, 0.5f, seed}(q_host);
     ck_tile::FillUniformDistribution<DataType>{-0.5f, 0.5f, seed + 1}(k_host);
     ck_tile::FillUniformDistribution<DataType>{-0.5f, 0.5f, seed + 2}(v_host);
 
-    // Build block map using Sparge tool
-    std::cout << "Building Sparge block map..." << std::endl;
-    sparge::SpargeParams p;
-    p.BLKQ = static_cast<ck_tile::index_t>(BLKQ);
-    p.BLKK = static_cast<ck_tile::index_t>(BLKK);
-    p.simthreshd1 = simthreshd1;
-    p.cdfthreshd = cdfthreshd;
-    p.topk = topk;
-    p.i_perm = i_perm;
-
-    ck_tile::HostTensor<uint8_t> block_relation_onehot =
-        sparge::build_block_map_meansim(q_host, k_host, p);
-
-    // Convert to VSA LUT (delta-encoded) + valid_block_num
-    std::cout << "Converting block map to VSA LUT (delta)..." << std::endl;
-    auto vsa_lut = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
-
-    // Print actual sparsity (based on one-hot)
-    std::size_t total_blocks = 0;
-    std::size_t active_blocks = 0;
-    for(ck_tile::index_t b = 0; b < batch; ++b)
-    {
-        for(ck_tile::index_t h = 0; h < nhead; ++h)
-        {
-            for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
-            {
-                for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
-                {
-                    total_blocks++;
-                    if(block_relation_onehot(b, h, qb, kb) != 0)
-                        active_blocks++;
-                }
-            }
-        }
-    }
-    float actual_sparsity =
-        1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
-    std::cout << "  Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
-              << total_blocks << " blocks active)" << std::endl;
+    // ==================================================================
+    // GPU: Build block map + VSA LUT in one kernel (always run)
+    // ==================================================================
+    std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl;
+    ck_tile::HostTensor<uint8_t> block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks});
+    auto vsa_lut_gpu = sparge_blockmap_gpu(q_host,
+                                           k_host,
+                                           block_map_gpu,
+                                           batch,
+                                           nhead,
+                                           nhead_k,
+                                           seqlen_q,
+                                           seqlen_k,
+                                           hdim_q,
+                                           i_perm,
+                                           simthreshd1,
+                                           cdfthreshd,
+                                           topk,
+                                           static_cast<int>(BLKQ),
+                                           static_cast<int>(BLKK),
+                                           0);
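+    // block_map_gpu keeps the one-hot map for the sparsity statistics and
+    // CPU cross-checks below; the attention kernel itself only consumes
+    // vsa_lut_gpu.lut (delta-encoded) and vsa_lut_gpu.valid_block_num.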
+
+    // ==================================================================
+    // VSA sparse attention kernel (always run)
+    // ==================================================================
     std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
 
     try
@@ -254,8 +239,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
         vsa_sparge_attention(q_host,
                              k_host,
                              v_host,
-                             vsa_lut.lut,
-                             vsa_lut.valid_block_num,
+                             vsa_lut_gpu.lut,
+                             vsa_lut_gpu.valid_block_num,
                              output_host,
                              batch,
                              nhead,
@@ -276,8 +261,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
         vsa_sparge_attention(q_host,
                              k_host,
                              v_host,
-                             vsa_lut.lut,
-                             vsa_lut.valid_block_num,
+                             vsa_lut_gpu.lut,
+                             vsa_lut_gpu.valid_block_num,
                              output_host,
                              batch,
                              nhead,
@@ -301,8 +286,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
         vsa_sparge_attention(q_host,
                              k_host,
                              v_host,
-                             vsa_lut.lut,
-                             vsa_lut.valid_block_num,
+                             vsa_lut_gpu.lut,
+                             vsa_lut_gpu.valid_block_num,
                              output_host,
                              batch,
                              nhead,
@@ -332,17 +317,168 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
         return false;
     }
 
+    // ==================================================================
+    // Sparsity statistics (always run, pure CPU read of HostTensor)
+    // ==================================================================
+    std::size_t total_blocks = 0;
+    std::size_t active_blocks = 0;
+    for(ck_tile::index_t b = 0; b < batch; ++b)
+    {
+        for(ck_tile::index_t h = 0; h < nhead; ++h)
+        {
+            for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
+            {
+                for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
+                {
+                    total_blocks++;
+                    if(block_map_gpu(b, h, qb, kb) != 0)
+                        active_blocks++;
+                }
+            }
+        }
+    }
+    float actual_sparsity =
+        1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
+    std::cout << "\n  Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
+              << total_blocks << " blocks active)" << std::endl;
+
+    // ==================================================================
+    // Validation (only when -v=1)
+    // ==================================================================
     bool pass = true;
     if(do_validation)
     {
         std::cout << "\n--- Performing CPU validation ---" << std::endl;
+
+        // CPU golden: block map + VSA LUT
+        std::cout << "Building Sparge block map (CPU golden)..." << std::endl;
+        sparge::SpargeParams p;
+        p.BLKQ = static_cast<ck_tile::index_t>(BLKQ);
+        p.BLKK = static_cast<ck_tile::index_t>(BLKK);
+        p.simthreshd1 = simthreshd1;
+        p.cdfthreshd = cdfthreshd;
+        p.topk = topk;
+        p.i_perm = i_perm;
+
+        ck_tile::HostTensor<uint8_t> block_relation_onehot =
+            sparge::build_block_map_meansim(q_host, k_host, p);
+
+        std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl;
+        auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
+
+        // Validate block map
+        std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl;
+        {
+            std::size_t bmap_mismatches = 0;
+            for(ck_tile::index_t b = 0; b < batch; ++b)
+            {
+                for(ck_tile::index_t h = 0; h < nhead; ++h)
+                {
+                    for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
+                    {
+                        for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
+                        {
+                            if(block_map_gpu(b, h, qb, kb) !=
+                               block_relation_onehot(b, h, qb, kb))
+                            {
+                                bmap_mismatches++;
+                                if(bmap_mismatches <= 10)
+                                {
+                                    std::cout
+                                        << "  block_map mismatch at [" << b << "," << h << ","
+                                        << qb << "," << kb
+                                        << "]: GPU="
+                                        << static_cast<int>(block_map_gpu(b, h, qb, kb))
+                                        << " CPU="
+                                        << static_cast<int>(
+                                               block_relation_onehot(b, h, qb, kb))
+                                        << std::endl;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            std::cout << "  Block map mismatches: " << bmap_mismatches << " / "
+                      << (batch * nhead * num_q_blocks * num_k_blocks) << std::endl;
+            if(bmap_mismatches > 0)
+            {
+                std::cout << ">>> GPU BLOCK MAP VALIDATION FAILED <<<" << std::endl;
+                pass = false;
+            }
+            else
+            {
+                std::cout << ">>> GPU BLOCK MAP VALIDATION PASSED <<<" << std::endl;
+            }
+        }
+
+        // Validate VSA LUT
+        std::cout << "\n--- Validating GPU VSA LUT vs CPU golden ---" << std::endl;
+        {
+            std::size_t lut_mismatches = 0;
+            std::size_t valid_mismatches = 0;
+            for(ck_tile::index_t b = 0; b < batch; ++b)
+            {
+                for(ck_tile::index_t h = 0; h < nhead; ++h)
+                {
+                    for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
+                    {
+                        if(vsa_lut_gpu.valid_block_num(b, h, qb) !=
+                           vsa_lut_cpu.valid_block_num(b, h, qb))
+                        {
+                            valid_mismatches++;
+                            if(valid_mismatches <= 5)
+                            {
+                                std::cout
+                                    << "  valid_block_num mismatch at [" << b << "," << h
+                                    << "," << qb
+                                    << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb)
+                                    << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb)
+                                    << std::endl;
+                            }
+                        }
+                        for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
+                        {
+                            if(vsa_lut_gpu.lut(b, h, qb, kb) !=
+                               vsa_lut_cpu.lut(b, h, qb, kb))
+                            {
+                                lut_mismatches++;
+                                if(lut_mismatches <= 10)
+                                {
+                                    std::cout
+                                        << "  LUT mismatch at [" << b << "," << h << "," << qb
+                                        << "," << kb
+                                        << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb)
+                                        << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb)
+                                        << std::endl;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            std::cout << "  LUT mismatches: " << lut_mismatches << std::endl;
+            std::cout << "  valid_block_num mismatches: " << valid_mismatches << std::endl;
+            if(lut_mismatches == 0 && valid_mismatches == 0)
+            {
+                std::cout << ">>> GPU VSA LUT VALIDATION PASSED <<<" << std::endl;
+            }
+            else
+            {
+                std::cout << ">>> GPU VSA LUT VALIDATION FAILED <<<" << std::endl;
+                pass = false;
+            }
+        }
+
+        // Validate attention output
         float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
 
-        std::cout << "Computing reference output..." << std::endl;
+        std::cout << "\nComputing reference attention output..." << std::endl;
         auto q_ref = to_bhsd(q_host, i_perm);
         auto k_ref = to_bhsd(k_host, i_perm);
         auto v_ref = to_bhsd(v_host, i_perm);
 
+        ck_tile::HostTensor<DataType> output_ref({batch, nhead, seqlen_q, hdim_v});
         ck_tile::reference_blocked_attention(
             q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
 
@@ -374,7 +510,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
             }
         }
 
-        std::cout << "\nValidation results:" << std::endl;
+        std::cout << "\nAttention validation results:" << std::endl;
         std::cout << "  Max absolute difference: " << max_diff << std::endl;
         std::cout << "  Max relative difference: " << max_rel_diff << std::endl;
         std::cout << "  Number of mismatches: " << num_errors << " / "
diff --git a/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp
new file mode 100644
index 0000000000..ca177abf23
--- /dev/null
+++ b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp
@@ -0,0 +1,195 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include <cstdint>
+
+namespace ck_tile {
+
+template <typename Pipeline_>
+struct SpargeBlockMapKernel
+{
+    using Pipeline = remove_cvref_t<Pipeline_>;
+
+    static constexpr index_t kBlockSize = Pipeline::kBlockSize;
+    static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
+
+    using QDataType = typename Pipeline::QDataType;
+    using KDataType = typename Pipeline::KDataType;
+
+    static constexpr index_t kM0 = Pipeline::kM0;
+    static constexpr index_t kN0 = Pipeline::kN0;
+    static constexpr index_t D = Pipeline::D;
+
+    static constexpr index_t kAlignment = 16 / sizeof(QDataType);
+
+    struct Kargs
+    {
+        const void* q_ptr;
+        const void* k_ptr;
+
+        index_t seqlen_q;
+        index_t seqlen_k;
+        index_t hdim_q;
+
+        index_t nhead_q;
+        index_t nhead_ratio_qk;
+
+        index_t stride_q;
+        index_t stride_k;
+        index_t nhead_stride_q;
+        index_t nhead_stride_k;
+        index_t batch_stride_q;
+        index_t batch_stride_k;
+
+        float simthreshd1;
+        float cdfthreshd;
+        float topk;
+        float scale;
+
+        void* block_map_ptr;
+        void* lut_ptr;
+        void* valid_block_num_ptr;
+
+        index_t N_k;
+    };
+
+    CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr,
+                                                 const void* k_ptr,
+                                                 index_t seqlen_q,
+                                                 index_t seqlen_k,
+                                                 index_t hdim_q,
+                                                 index_t nhead_q,
+                                                 index_t nhead_ratio_qk,
+                                                 index_t stride_q,
+                                                 index_t stride_k,
+                                                 index_t nhead_stride_q,
+                                                 index_t nhead_stride_k,
+                                                 index_t batch_stride_q,
+                                                 index_t batch_stride_k,
+                                                 float simthreshd1,
+                                                 float cdfthreshd,
+                                                 float topk,
+                                                 float scale,
+                                                 void* block_map_ptr,
+                                                 void* lut_ptr,
+                                                 void* valid_block_num_ptr)
+    {
+        const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
+        return Kargs{q_ptr,
+                     k_ptr,
+                     seqlen_q,
+                     seqlen_k,
+                     hdim_q,
+                     nhead_q,
+                     nhead_ratio_qk,
+                     stride_q,
+                     stride_k,
+                     nhead_stride_q,
+                     nhead_stride_k,
+                     batch_stride_q,
+                     batch_stride_k,
+                     simthreshd1,
+                     cdfthreshd,
+                     topk,
+                     scale,
+                     block_map_ptr,
+                     lut_ptr,
+                     valid_block_num_ptr,
+                     N_k};
+    }
+
+    CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q)
+    {
+        const index_t Q_blk = integer_divide_ceil(seqlen_q, kM0);
+        return dim3(Q_blk, nhead_q, batch);
+    }
+
+    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
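+
+    // One workgroup handles one (q_block, head, batch) triple. With, say,
+    // batch = 2, nhead_q = 8, seqlen_q = 1024 and kM0 = 64 (illustrative
+    // sizes), GridSize returns dim3(16, 8, 2): 256 workgroups, each emitting
+    // one row of the block map and LUT.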
+
+    CK_TILE_DEVICE void operator()(Kargs kargs) const
+    {
+        const index_t qb = static_cast<index_t>(blockIdx.x);
+        const index_t hq = static_cast<index_t>(blockIdx.y);
+        const index_t b  = static_cast<index_t>(blockIdx.z);
+
+        const index_t hk = hq / kargs.nhead_ratio_qk;
+
+        // Q pointer for this (batch, head, q_block)
+        const auto* q_base = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
+                             b * kargs.batch_stride_q + hq * kargs.nhead_stride_q +
+                             qb * kM0 * kargs.stride_q;
+
+        // K pointer for this (batch, head_k)
+        const auto* k_base = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
+                             b * kargs.batch_stride_k + hk * kargs.nhead_stride_k;
+
+        // Q DRAM view with OOB padding
+        const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+            q_base,
+            make_tuple(kargs.seqlen_q - qb * kM0, D),
+            make_tuple(kargs.stride_q, 1),
+            number<kAlignment>{},
+            number<1>{});
+        const auto q_dram = pad_tensor_view(
+            q_dram_naive, make_tuple(number<kM0>{}, number<D>{}), sequence<true, true>{});
+
+        auto q_window = make_tile_window(q_dram,
+                                         make_tuple(number<kM0>{}, number<D>{}),
+                                         {0, 0},
+                                         Pipeline::MakeQBlockDistribution());
+
+        // K DRAM view with OOB padding
+        const auto k_dram_naive =
+            make_naive_tensor_view<address_space_enum::global>(k_base,
+                                                               make_tuple(kargs.seqlen_k, D),
+                                                               make_tuple(kargs.stride_k, 1),
+                                                               number<kAlignment>{},
+                                                               number<1>{});
+        const auto k_dram = pad_tensor_view(
+            k_dram_naive, make_tuple(number<kN0>{}, number<D>{}), sequence<true, true>{});
+
+        auto k_window = make_tile_window(k_dram,
+                                         make_tuple(number<kN0>{}, number<D>{}),
+                                         {0, 0},
+                                         Pipeline::MakeKBlockDistribution());
+
+        // Output pointers for this (batch, head, q_block)
+        const index_t N_k = kargs.N_k;
+        const index_t bmap_offset =
+            (b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) * N_k + qb * N_k;
+        auto* bmap_ptr = reinterpret_cast<uint8_t*>(kargs.block_map_ptr) + bmap_offset;
+
+        int32_t* lut_out   = nullptr;
+        int32_t* valid_out = nullptr;
+        if(kargs.lut_ptr != nullptr)
+        {
+            lut_out = reinterpret_cast<int32_t*>(kargs.lut_ptr) + bmap_offset;
+            const index_t valid_offset =
+                (b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) + qb;
+            valid_out = reinterpret_cast<int32_t*>(kargs.valid_block_num_ptr) + valid_offset;
+        }
+
+        // Shared memory
+        __shared__ char smem[Pipeline::GetSmemSize()];
+
+        Pipeline{}(q_window,
+                   k_window,
+                   kargs.seqlen_q,
+                   kargs.seqlen_k,
+                   qb,
+                   N_k,
+                   kargs.nhead_ratio_qk,
+                   kargs.simthreshd1,
+                   kargs.cdfthreshd,
+                   kargs.topk,
+                   kargs.scale,
+                   bmap_ptr,
+                   lut_out,
+                   valid_out,
+                   static_cast<void*>(smem));
+    }
+};
+
+} // namespace ck_tile
diff --git a/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp
new file mode 100644
index 0000000000..222e73c60e
--- /dev/null
+++ b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp
@@ -0,0 +1,521 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/reduce.hpp"
+
+namespace ck_tile {
+
+template <typename Problem_>
+struct SpargeBlockMapPipeline
+{
+    using Problem        = remove_cvref_t<Problem_>;
+    using QDataType      = remove_cvref_t<typename Problem::QDataType>;
+    using KDataType      = remove_cvref_t<typename Problem::KDataType>;
+    using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
+
+    static constexpr index_t kBlockSize = Problem::kBlockSize;
+    static constexpr index_t kM0 = BlockFmhaShape::kM0;
+    static constexpr index_t kN0 = BlockFmhaShape::kN0;
+    static constexpr index_t D = BlockFmhaShape::kQKHeaddim;
+    static constexpr index_t NumWarps = BlockFmhaShape::NumWarps;
+    static constexpr index_t WarpSize = get_warp_size();
+
+    static constexpr index_t KPerThread = 16 / sizeof(QDataType);
+    static constexpr index_t KThreads = D / KPerThread;
+    static constexpr index_t SeqThreadPerWarp = WarpSize / KThreads;
+    static constexpr index_t MPerThread = kM0 / (SeqThreadPerWarp * NumWarps);
+    static constexpr index_t NPerThread = kN0 / (SeqThreadPerWarp * NumWarps);
+
+    static constexpr index_t kBlockPerCu = 1;
+    static constexpr index_t kMaxKBlocks = 1024;
+
+    // LDS layout (non-overlapping, all used simultaneously in Phase 2):
+    //   [0 .. kReduceBytes)   cross-warp reduction scratch
+    //   [kScoreOffset ..)     scores[N_k]
+    //   [kBmapOffset ..)      block_map[N_k]
+    //   [kSmallOffset ..)     Phase 3 argmax scratch (2*NumWarps floats)
+    static constexpr index_t kReduceBytes = NumWarps * D * sizeof(float);
+    static constexpr index_t kScoreOffset = kReduceBytes;
+    static constexpr index_t kBmapOffset = kScoreOffset + kMaxKBlocks * sizeof(float);
+    static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t);
+
+    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    {
+        return kSmallOffset + 2 * NumWarps * sizeof(float);
+    }
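+
+    // Worked example, assuming the fp16 D=128 instantiation on a wave64 GPU
+    // with kBlockSize = 256 (NumWarps = 4): KPerThread = 8, KThreads = 16,
+    // SeqThreadPerWarp = 4, so MPerThread = 64 / 16 = 4 and
+    // NPerThread = 128 / 16 = 8. GetSmemSize() = 2048 + 4096 + 1024 + 32
+    // = 7200 bytes. N_k must not exceed kMaxKBlocks = 1024, or the scores /
+    // block_map LDS arrays would overflow.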
+
+    CK_TILE_HOST_DEVICE static constexpr auto MakeQBlockDistribution()
+    {
+        return make_static_tile_distribution(
+            tile_distribution_encoding<sequence<>,
+                                       tuple<sequence<MPerThread, NumWarps, SeqThreadPerWarp>,
+                                             sequence<KThreads, KPerThread>>,
+                                       tuple<sequence<1>, sequence<1, 2>>,
+                                       tuple<sequence<1>, sequence<2, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 1>>{});
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr auto MakeKBlockDistribution()
+    {
+        return make_static_tile_distribution(
+            tile_distribution_encoding<sequence<>,
+                                       tuple<sequence<NPerThread, NumWarps, SeqThreadPerWarp>,
+                                             sequence<KThreads, KPerThread>>,
+                                       tuple<sequence<1>, sequence<1, 2>>,
+                                       tuple<sequence<1>, sequence<2, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 1>>{});
+    }
+
+    // Extract tile data into a local float array via static_for (compile-time indices).
+    template <typename Tile, index_t BufSize>
+    CK_TILE_DEVICE static void tile_to_float(const Tile& tile, float (&out)[BufSize])
+    {
+        static_assert(Tile::get_thread_buffer_size() == BufSize);
+        const auto& buf = tile.get_thread_buffer();
+        static_for<0, BufSize, 1>{}([&](auto i) { out[i.value] = type_convert<float>(buf[i]); });
+    }
+
+    // Column-wise (dim=0) sum: accumulate SeqPerThread rows into KPerThread partial sums,
+    // then xor-shuffle across m_idx within warp.
+    template <index_t SeqPerThread>
+    CK_TILE_DEVICE static void column_reduce_thread_and_warp(const float* __restrict__ data,
+                                                             float (&col_acc)[KPerThread])
+    {
+        for(index_t k = 0; k < KPerThread; ++k)
+            col_acc[k] = 0.f;
+
+        for(index_t m = 0; m < SeqPerThread; ++m)
+            for(index_t k = 0; k < KPerThread; ++k)
+                col_acc[k] += data[m * KPerThread + k];
+
+        for(index_t stride = KThreads; stride < WarpSize; stride *= 2)
+            for(index_t k = 0; k < KPerThread; ++k)
+                col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride);
+    }
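+
+    // The xor strides start at KThreads because lanes are laid out as
+    // lane = m_idx * KThreads + k_idx: flipping bits at or above KThreads
+    // exchanges partials between lanes with different m_idx but the same
+    // column slice, so after log2(SeqThreadPerWarp) rounds every lane owns
+    // the warp-wide column sum for its k_idx slice.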
+
+    // Cross-warp LDS reduction for column sums.
+    CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread],
+                                                        float* __restrict__ smem_reduce)
+    {
+        const index_t tid = static_cast<index_t>(threadIdx.x);
+        const index_t warp_id = tid / WarpSize;
+        const index_t lane_id = tid % WarpSize;
+        const index_t k_idx = lane_id % KThreads;
+        const index_t m_idx = lane_id / KThreads;
+
+        if(m_idx == 0)
+            for(index_t k = 0; k < KPerThread; ++k)
+                smem_reduce[warp_id * D + k_idx * KPerThread + k] = col_acc[k];
+        __syncthreads();
+
+        for(index_t k = 0; k < KPerThread; ++k)
+            col_acc[k] = 0.f;
+        for(index_t w = 0; w < NumWarps; ++w)
+            for(index_t k = 0; k < KPerThread; ++k)
+                col_acc[k] += smem_reduce[w * D + k_idx * KPerThread + k];
+        __syncthreads();
+    }
+
+    // Compute ||v||^2 per row: sum along KPerThread then xor-shuffle across k_idx.
+    template <index_t SeqPerThread>
+    CK_TILE_DEVICE static void row_reduce_sq_norm(const float* __restrict__ data,
+                                                  float (&row_norms)[SeqPerThread],
+                                                  index_t actual_seq)
+    {
+        const index_t tid = static_cast<index_t>(threadIdx.x);
+        const index_t warp_id = tid / WarpSize;
+        const index_t m_idx = (tid % WarpSize) / KThreads;
+
+        for(index_t m = 0; m < SeqPerThread; ++m)
+        {
+            float sq = 0.f;
+            for(index_t k = 0; k < KPerThread; ++k)
+            {
+                float v = data[m * KPerThread + k];
+                sq += v * v;
+            }
+            for(index_t stride = 1; stride < KThreads; stride *= 2)
+                sq += warp_shuffle(sq, __lane_id() ^ stride);
+
+            index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx;
+            row_norms[m] = (gsq < actual_seq) ? sq : 0.f;
+        }
+    }
+
+    // Column reduce of normalised rows: sum_hat[d] = sum_i data[i,d] / ||data[i,:]||.
+    template <index_t SeqPerThread>
+    CK_TILE_DEVICE static void column_reduce_normalised(const float* __restrict__ data,
+                                                        const float* __restrict__ row_norms,
+                                                        float (&col_acc)[KPerThread],
+                                                        index_t actual_seq)
+    {
+        const index_t tid = static_cast<index_t>(threadIdx.x);
+        const index_t warp_id = tid / WarpSize;
+        const index_t m_idx = (tid % WarpSize) / KThreads;
+
+        for(index_t k = 0; k < KPerThread; ++k)
+            col_acc[k] = 0.f;
+
+        for(index_t m = 0; m < SeqPerThread; ++m)
+        {
+            float inv_norm = (row_norms[m] > 0.f) ? (1.0f / __builtin_sqrtf(row_norms[m])) : 0.f;
+            index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx;
+            if(gsq < actual_seq)
+                for(index_t k = 0; k < KPerThread; ++k)
+                    col_acc[k] += data[m * KPerThread + k] * inv_norm;
+        }
+
+        for(index_t stride = KThreads; stride < WarpSize; stride *= 2)
+            for(index_t k = 0; k < KPerThread; ++k)
+                col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride);
+    }
+
+    // Scalar reduce across k_idx lanes (within warp).
+    CK_TILE_DEVICE static float reduce_across_k(float v)
+    {
+        for(index_t stride = 1; stride < KThreads; stride *= 2)
+            v += warp_shuffle(v, __lane_id() ^ stride);
+        return v;
+    }
+
+    // Full-block scalar reduce (warp xor + cross-warp LDS).
+    CK_TILE_DEVICE static float block_reduce_sum(float v, float* smem_small)
+    {
+        const index_t tid = static_cast<index_t>(threadIdx.x);
+        const index_t warp_id = tid / WarpSize;
+        const index_t lane_id = tid % WarpSize;
+
+        for(index_t stride = 1; stride < WarpSize; stride *= 2)
+            v += warp_shuffle(v, __lane_id() ^ stride);
+        if(lane_id == 0)
+            smem_small[warp_id] = v;
+        __syncthreads();
+        if(tid == 0)
+        {
+            float s = 0.f;
+            for(index_t w = 0; w < NumWarps; ++w)
+                s += smem_small[w];
+            smem_small[0] = s;
+        }
+        __syncthreads();
+        return smem_small[0];
+    }
+
+    CK_TILE_DEVICE static float block_reduce_max(float v, float* smem_small)
+    {
+        const index_t tid = static_cast<index_t>(threadIdx.x);
+        const index_t warp_id = tid / WarpSize;
+        const index_t lane_id = tid % WarpSize;
+
+        for(index_t stride = 1; stride < WarpSize; stride *= 2)
+            v = max(v, warp_shuffle(v, __lane_id() ^ stride));
+        if(lane_id == 0)
+            smem_small[warp_id] = v;
+        __syncthreads();
+        if(tid == 0)
+        {
+            float s = smem_small[0];
+            for(index_t w = 1; w < NumWarps; ++w)
+                s = max(s, smem_small[w]);
+            smem_small[0] = s;
+        }
+        __syncthreads();
+        return smem_small[0];
+    }
+
+    // ======================================================================
+    template <typename QWindowType, typename KWindowType>
+    CK_TILE_DEVICE void operator()(const QWindowType& q_window_in,
+                                   const KWindowType& k_window_in,
+                                   index_t seqlen_q,
+                                   index_t seqlen_k,
+                                   index_t qb,
+                                   index_t N_k,
+                                   index_t /*nhead_ratio_qk*/,
+                                   float simthreshd1,
+                                   float cdfthreshd,
+                                   float topk,
+                                   float scale,
+                                   uint8_t* block_map_ptr,
+                                   int32_t* lut_ptr,
+                                   int32_t* valid_block_num_ptr,
+                                   void* smem_ptr) const
+    {
+        const index_t tid = static_cast<index_t>(threadIdx.x);
+
+        auto* smem_float = reinterpret_cast<float*>(smem_ptr);
+        auto* smem_scores =
+            reinterpret_cast<float*>(reinterpret_cast<char*>(smem_ptr) + kScoreOffset);
+        auto* smem_bmap =
+            reinterpret_cast<uint8_t*>(reinterpret_cast<char*>(smem_ptr) + kBmapOffset);
+        auto* smem_small =
+            reinterpret_cast<float*>(reinterpret_cast<char*>(smem_ptr) + kSmallOffset);
+
+        const index_t bs_q = min(static_cast<index_t>(kM0), seqlen_q - qb * kM0);
+        const float inv_bs_q = (bs_q > 0) ? (1.0f / static_cast<float>(bs_q)) : 0.f;
+
+        // ==================================================================
+        // Phase 1: Q Block Statistics
+        // ==================================================================
+        auto q_tile = load_tile(q_window_in);
+
+        float q_data[MPerThread * KPerThread];
+        tile_to_float(q_tile, q_data);
+
+        // 1a. L2 norm per token
+        float psq[MPerThread];
+        row_reduce_sq_norm(q_data, psq, bs_q);
+
+        // 1b. Column sum -> mean
+        float pooled_q_mean[KPerThread];
+        column_reduce_thread_and_warp<MPerThread>(q_data, pooled_q_mean);
+        column_reduce_cross_warp(pooled_q_mean, smem_float);
+        for(index_t k = 0; k < KPerThread; ++k)
+            pooled_q_mean[k] *= inv_bs_q;
+
+        // 1c. Normalised sum_hat
+        float sum_hat[KPerThread];
+        column_reduce_normalised<MPerThread>(q_data, psq, sum_hat, bs_q);
+        column_reduce_cross_warp(sum_hat, smem_float);
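+
+        // ||sum_hat||^2 expands to sum_{i,j} cos(q_i, q_j) because each row
+        // was scaled to unit length first, so dividing by bs_q^2 below yields
+        // the mean pairwise cosine similarity of the tokens in this Q block
+        // (diagonal included). Values near 1 mean the block is well
+        // represented by its pooled mean.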
+        // 1d. sim_q = ||sum_hat||^2 / bs_q^2
+        float sh_sq = 0.f;
+        for(index_t k = 0; k < KPerThread; ++k)
+            sh_sq += sum_hat[k] * sum_hat[k];
+        sh_sq = reduce_across_k(sh_sq);
+        const float denom_q = static_cast<float>(bs_q) * static_cast<float>(bs_q);
+        const bool sim_q = (denom_q > 0.f) && ((sh_sq / denom_q) > simthreshd1);
+
+        // Not similar -> force all K blocks ON, early exit
+        if(!sim_q)
+        {
+            for(index_t i = tid; i < N_k; i += kBlockSize)
+                block_map_ptr[i] = 1;
+
+            if(lut_ptr != nullptr && tid == 0)
+            {
+                int32_t valid = 0, prev = 0;
+                for(index_t kb = 0; kb < N_k; ++kb)
+                {
+                    lut_ptr[valid] = static_cast<int32_t>(kb) - prev;
+                    prev = static_cast<int32_t>(kb);
+                    ++valid;
+                }
+                for(index_t i = valid; i < N_k; ++i)
+                    lut_ptr[i] = 0;
+                *valid_block_num_ptr = valid;
+            }
+            return;
+        }
+
+        // ==================================================================
+        // Phase 2: K Block Loop
+        // ==================================================================
+        for(index_t i = tid; i < N_k; i += kBlockSize)
+            smem_bmap[i] = 0;
+        __syncthreads();
+
+        auto k_window = k_window_in;
+
+        for(index_t kb = 0; kb < N_k; ++kb)
+        {
+            const index_t bs_k = min(static_cast<index_t>(kN0), seqlen_k - kb * kN0);
+            const float inv_bs_k = (bs_k > 0) ? (1.0f / static_cast<float>(bs_k)) : 0.f;
+
+            auto k_tile = load_tile(k_window);
+
+            float k_data[NPerThread * KPerThread];
+            tile_to_float(k_tile, k_data);
+
+            // K mean
+            float pooled_k_mean[KPerThread];
+            column_reduce_thread_and_warp<NPerThread>(k_data, pooled_k_mean);
+            column_reduce_cross_warp(pooled_k_mean, smem_float);
+            for(index_t k = 0; k < KPerThread; ++k)
+                pooled_k_mean[k] *= inv_bs_k;
+
+            // dot(pooled_q_mean, pooled_k_mean)
+            float dot = 0.f;
+            for(index_t k = 0; k < KPerThread; ++k)
+                dot += pooled_q_mean[k] * pooled_k_mean[k];
+            dot = reduce_across_k(dot);
+
+            // K L2 norms + normalised sum_hat
+            float k_psq[NPerThread];
+            row_reduce_sq_norm(k_data, k_psq, bs_k);
+
+            float k_sum_hat[KPerThread];
+            column_reduce_normalised<NPerThread>(k_data, k_psq, k_sum_hat, bs_k);
+            column_reduce_cross_warp(k_sum_hat, smem_float);
+
+            // sim_k
+            float ksh_sq = 0.f;
+            for(index_t k = 0; k < KPerThread; ++k)
+                ksh_sq += k_sum_hat[k] * k_sum_hat[k];
+            ksh_sq = reduce_across_k(ksh_sq);
+            const float denom_k = static_cast<float>(bs_k) * static_cast<float>(bs_k);
+            const bool sim_k = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1);
+
+            if(tid == 0)
+            {
+                if(!sim_k)
+                {
+                    smem_bmap[kb] = 1;
+                    smem_scores[kb] = -numeric<float>::infinity();
+                }
+                else
+                {
+                    smem_scores[kb] = dot * scale;
+                }
+            }
+            __syncthreads();
+
+            move_tile_window(k_window, {kN0, 0});
+        }
+
+        // ==================================================================
+        // Phase 3: Softmax + Selection
+        // ==================================================================
+
+        // max
+        float lmax = -numeric<float>::infinity();
+        for(index_t i = tid; i < N_k; i += kBlockSize)
+            lmax = max(lmax, smem_scores[i]);
+        const float max_score = block_reduce_max(lmax, smem_small);
+
+        // exp + sum
+        float lsum = 0.f;
+        for(index_t i = tid; i < N_k; i += kBlockSize)
+        {
+            float e = (smem_scores[i] > -numeric<float>::infinity())
+                          ? __builtin_expf(smem_scores[i] - max_score)
+                          : 0.f;
+            smem_scores[i] = e;
+            lsum += e;
+        }
+        const float sum_exp = block_reduce_sum(lsum, smem_small);
+
+        // normalise
+        const float inv_sum = (sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
+        for(index_t i = tid; i < N_k; i += kBlockSize)
+            smem_scores[i] *= inv_sum;
+        __syncthreads();
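+
+        // Two selection modes follow. With topk > 0 a fixed count is kept:
+        // e.g. topk = 0.25 and N_k = 32 keeps max(1, 8) = 8 blocks
+        // (illustrative numbers). With topk <= 0 the argmax loop instead
+        // keeps drawing blocks until the accumulated softmax mass reaches
+        // cdfthreshd.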
+        // Selection: iterative argmax
+        index_t num_to_select =
+            (topk > 0.f) ? max(static_cast<index_t>(1),
+                               static_cast<index_t>(topk * static_cast<float>(N_k)))
+                         : N_k;
+
+        float cumulative_prob = 0.f;
+        for(index_t round = 0; round < num_to_select; ++round)
+        {
+            // thread-local argmax
+            float best_val = -1.f;
+            index_t best_idx = 0;
+            for(index_t i = tid; i < N_k; i += kBlockSize)
+            {
+                if(smem_scores[i] > best_val || (smem_scores[i] == best_val && i < best_idx))
+                {
+                    best_val = smem_scores[i];
+                    best_idx = i;
+                }
+            }
+
+            // warp argmax
+            for(index_t stride = 1; stride < WarpSize; stride *= 2)
+            {
+                float rv = warp_shuffle(best_val, __lane_id() ^ stride);
+                index_t ri = warp_shuffle(best_idx, __lane_id() ^ stride);
+                if(rv > best_val || (rv == best_val && ri < best_idx))
+                {
+                    best_val = rv;
+                    best_idx = ri;
+                }
+            }
+
+            // cross-warp argmax via LDS
+            const index_t lane_id = tid % WarpSize;
+            const index_t warp_id = tid / WarpSize;
+            if(lane_id == 0)
+            {
+                smem_small[warp_id] = best_val;
+                smem_small[NumWarps + warp_id] = bit_cast<float>(static_cast<int32_t>(best_idx));
+            }
+            __syncthreads();
+
+            if(tid == 0)
+            {
+                float bv = smem_small[0];
+                index_t bi = bit_cast<int32_t>(smem_small[NumWarps]);
+                for(index_t w = 1; w < NumWarps; ++w)
+                {
+                    float wv = smem_small[w];
+                    index_t wi = bit_cast<int32_t>(smem_small[NumWarps + w]);
+                    if(wv > bv || (wv == bv && wi < bi))
+                    {
+                        bv = wv;
+                        bi = wi;
+                    }
+                }
+                smem_small[0] = bv;
+                smem_small[1] = bit_cast<float>(static_cast<int32_t>(bi));
+            }
+            __syncthreads();
+
+            float g_val = smem_small[0];
+            index_t g_idx = bit_cast<int32_t>(smem_small[1]);
+
+            if(g_val <= 0.f)
+                break;
+
+            if(tid == 0)
+            {
+                smem_bmap[g_idx] = 1;
+                smem_scores[g_idx] = -1.f;
+            }
+            __syncthreads();
+
+            if(topk > 0.f)
+            {
+                if(round + 1 >= num_to_select)
+                    break;
+            }
+            else
+            {
+                cumulative_prob += g_val;
+                if(cumulative_prob >= cdfthreshd)
+                    break;
+            }
+        }
+
+        // ==================================================================
+        // Write outputs to global memory
+        // ==================================================================
+        for(index_t i = tid; i < N_k; i += kBlockSize)
+            block_map_ptr[i] = smem_bmap[i];
+
+        if(lut_ptr != nullptr && tid == 0)
+        {
+            int32_t valid = 0, prev = 0;
+            for(index_t kb = 0; kb < N_k; ++kb)
+            {
+                if(smem_bmap[kb] != 0)
+                {
+                    lut_ptr[valid] = static_cast<int32_t>(kb) - prev;
+                    prev = static_cast<int32_t>(kb);
+                    ++valid;
+                }
+            }
+            for(index_t i = valid; i < N_k; ++i)
+                lut_ptr[i] = 0;
+            *valid_block_num_ptr = valid;
+        }
+    }
+};
+
+} // namespace ck_tile