mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add sparge gpu pipeline in tile_example_sparge_vsa_sparse_attn
This commit is contained in:
@@ -266,11 +266,41 @@ target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# Sparge + VSA Example executable
|
||||
# ============================================================================
|
||||
# Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen)
|
||||
# ============================================================================
|
||||
set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances")
|
||||
|
||||
add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL
|
||||
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
|
||||
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp
|
||||
)
|
||||
target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
|
||||
)
|
||||
set_source_files_properties(
|
||||
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
|
||||
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp
|
||||
PROPERTIES LANGUAGE HIP
|
||||
)
|
||||
set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
|
||||
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
|
||||
-DCK_TILE_FMHA_FWD_FAST_EXP2
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
)
|
||||
|
||||
# Sparge + VSA Example executable (now links blockmap kernel too)
|
||||
set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn")
|
||||
message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}")
|
||||
add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp)
|
||||
target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARGE_VSA_INSTANCES})
|
||||
target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}
|
||||
${SPARGE_VSA_INSTANCES}
|
||||
${SPARGE_BLOCKMAP_INSTANCES}
|
||||
)
|
||||
target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
|
||||
156
example/ck_tile/50_sparse_attn/sparge_blockmap.cpp
Normal file
156
example/ck_tile/50_sparse_attn/sparge_blockmap.cpp
Normal file
@@ -0,0 +1,156 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "sparge_blockmap.h"
|
||||
#include "sparge_blockmap_trek.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
#include <cmath>
|
||||
|
||||
template <typename DataType_>
|
||||
sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
ck_tile::HostTensor<uint8_t>& block_map_out,
|
||||
int batch,
|
||||
int nhead_q,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
bool i_perm,
|
||||
float simthreshd1,
|
||||
float cdfthreshd,
|
||||
float topk,
|
||||
int blkq,
|
||||
int blkk,
|
||||
int log_level)
|
||||
{
|
||||
static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
|
||||
std::is_same_v<DataType_, ck_tile::bf16_t>,
|
||||
"sparge_blockmap_gpu supports fp16/bf16 only.");
|
||||
|
||||
std::string data_type = "fp16";
|
||||
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
|
||||
{
|
||||
data_type = "bf16";
|
||||
}
|
||||
|
||||
const ck_tile::index_t num_q_blocks = ck_tile::integer_divide_ceil(seqlen_q, blkq);
|
||||
const ck_tile::index_t num_k_blocks = ck_tile::integer_divide_ceil(seqlen_k, blkk);
|
||||
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
|
||||
|
||||
const std::size_t bmap_bytes =
|
||||
static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(uint8_t);
|
||||
const std::size_t lut_bytes =
|
||||
static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(int32_t);
|
||||
const std::size_t valid_bytes =
|
||||
static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * sizeof(int32_t);
|
||||
|
||||
ck_tile::DeviceMem bmap_buf(bmap_bytes);
|
||||
ck_tile::DeviceMem lut_buf(lut_bytes);
|
||||
ck_tile::DeviceMem valid_buf(valid_bytes);
|
||||
|
||||
q_buf.ToDevice(TQ.data());
|
||||
k_buf.ToDevice(TK.data());
|
||||
bmap_buf.SetZero();
|
||||
lut_buf.SetZero();
|
||||
valid_buf.SetZero();
|
||||
|
||||
// Compute strides (assumes BHSD if i_perm, BSHD otherwise)
|
||||
const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead_q * hdim_q;
|
||||
const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q;
|
||||
const ck_tile::index_t nhead_stride_q =
|
||||
i_perm ? static_cast<ck_tile::index_t>(seqlen_q) * hdim_q : hdim_q;
|
||||
const ck_tile::index_t nhead_stride_k =
|
||||
i_perm ? static_cast<ck_tile::index_t>(seqlen_k) * hdim_q : hdim_q;
|
||||
const ck_tile::index_t batch_stride_q =
|
||||
static_cast<ck_tile::index_t>(nhead_q) * seqlen_q * hdim_q;
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
static_cast<ck_tile::index_t>(nhead_k) * seqlen_k * hdim_q;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr, false, log_level, 0, 1, false};
|
||||
|
||||
sparge_blockmap_args args;
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.batch = batch;
|
||||
args.seqlen_q = seqlen_q;
|
||||
args.seqlen_k = seqlen_k;
|
||||
args.hdim_q = hdim_q;
|
||||
args.nhead_q = nhead_q;
|
||||
args.nhead_k = nhead_k;
|
||||
args.stride_q = stride_q;
|
||||
args.stride_k = stride_k;
|
||||
args.nhead_stride_q = nhead_stride_q;
|
||||
args.nhead_stride_k = nhead_stride_k;
|
||||
args.batch_stride_q = batch_stride_q;
|
||||
args.batch_stride_k = batch_stride_k;
|
||||
args.simthreshd1 = simthreshd1;
|
||||
args.cdfthreshd = cdfthreshd;
|
||||
args.topk = topk;
|
||||
args.scale = scale;
|
||||
args.block_map_ptr = bmap_buf.GetDeviceBuffer();
|
||||
args.lut_ptr = lut_buf.GetDeviceBuffer();
|
||||
args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
|
||||
|
||||
sparge_blockmap_traits traits;
|
||||
traits.data_type = data_type;
|
||||
traits.hdim_q = hdim_q;
|
||||
|
||||
sparge_blockmap_fwd(traits, args, stream_config);
|
||||
|
||||
// Copy results back to host
|
||||
bmap_buf.FromDevice(block_map_out.data(), bmap_bytes);
|
||||
|
||||
sparge::VSALut vsa_lut{
|
||||
ck_tile::HostTensor<int32_t>({batch, nhead_q, num_q_blocks, num_k_blocks}),
|
||||
ck_tile::HostTensor<int32_t>({batch, nhead_q, num_q_blocks}),
|
||||
};
|
||||
lut_buf.FromDevice(vsa_lut.lut.data(), lut_bytes);
|
||||
valid_buf.FromDevice(vsa_lut.valid_block_num.data(), valid_bytes);
|
||||
|
||||
return vsa_lut;
|
||||
}
|
||||
|
||||
// Explicit template instantiations
|
||||
template sparge::VSALut
|
||||
sparge_blockmap_gpu<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
ck_tile::HostTensor<uint8_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
|
||||
template sparge::VSALut
|
||||
sparge_blockmap_gpu<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
ck_tile::HostTensor<uint8_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
26
example/ck_tile/50_sparse_attn/sparge_blockmap.h
Normal file
26
example/ck_tile/50_sparse_attn/sparge_blockmap.h
Normal file
@@ -0,0 +1,26 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "sparge_tool.hpp"
|
||||
|
||||
template <typename DataType_>
|
||||
sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
ck_tile::HostTensor<uint8_t>& block_map_out,
|
||||
int batch,
|
||||
int nhead_q,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
bool i_perm,
|
||||
float simthreshd1,
|
||||
float cdfthreshd,
|
||||
float topk,
|
||||
int blkq,
|
||||
int blkk,
|
||||
int log_level = 0);
|
||||
88
example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp
Normal file
88
example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Hand-written template instantiation for SpargeBlockMapKernel (fp16, D=128).
|
||||
|
||||
#include "sparge_blockmap_trek.hpp"
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
|
||||
#include <iostream>
|
||||
|
||||
// ============================================================================
|
||||
// Type configuration for block map kernel (reuses FmhaSparseFwdTypeConfig)
|
||||
// ============================================================================
|
||||
|
||||
// fp16: D=128, kM0=64, kN0=128
|
||||
using bmap_fp16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>;
|
||||
// kM0 kN0 kK0 kN1 kK1 kQKHeaddim(D)
|
||||
|
||||
using bmap_fp16_shape =
|
||||
ck_tile::TileFmhaShape<bmap_fp16_block_tile,
|
||||
ck_tile::sequence<4, 1, 1>, // Gemm0BlockWarps
|
||||
ck_tile::sequence<16, 16, 16>, // Gemm0WarpTile (unused by blockmap, but
|
||||
// needed by shape)
|
||||
ck_tile::sequence<4, 1, 1>, // Gemm1BlockWarps
|
||||
ck_tile::sequence<16, 16, 16>, // Gemm1WarpTile
|
||||
true>; // VLayout row-major
|
||||
|
||||
using bmap_fp16_trait = ck_tile::TileFmhaTraits<true, // kPadSeqLenQ
|
||||
true, // kPadSeqLenK
|
||||
true, // kPadHeadDimQ
|
||||
true, // kPadHeadDimV
|
||||
false, // kHasLogitsSoftCap
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false, // kStoreLSE
|
||||
false, // kHasDropout
|
||||
false, // kHasRandVal
|
||||
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE,
|
||||
-1, // kBlockPerCu
|
||||
false>; // kIsVRowMajorSkip
|
||||
|
||||
using bmap_fp16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
using bmap_fp16_mask = ck_tile::GenericAttentionMask<false>;
|
||||
|
||||
using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::half_t, // QDataType
|
||||
ck_tile::half_t, // KDataType
|
||||
ck_tile::half_t, // VDataType
|
||||
float, // SaccDataType
|
||||
float, // SMPLComputeDataType
|
||||
ck_tile::half_t, // BiasDataType
|
||||
uint8_t, // RandValOutputDataType
|
||||
float, // LSEDataType
|
||||
ck_tile::half_t, // PDataType
|
||||
float, // OaccDataType
|
||||
ck_tile::half_t, // ODataType
|
||||
bmap_fp16_shape,
|
||||
false, // kIsGroupMode
|
||||
bmap_fp16_variant,
|
||||
bmap_fp16_mask,
|
||||
false, // kUseTrLoad
|
||||
bmap_fp16_trait>;
|
||||
|
||||
using bmap_fp16_pipeline = ck_tile::SpargeBlockMapPipeline<bmap_fp16_problem>;
|
||||
using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel<bmap_fp16_pipeline>;
|
||||
|
||||
// ============================================================================
|
||||
// Dispatch
|
||||
// ============================================================================
|
||||
|
||||
float sparge_blockmap_fwd(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(traits.data_type == "fp16" && traits.hdim_q == 128)
|
||||
{
|
||||
using k_ = bmap_fp16_kernel;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", sparge_blockmap_fp16_d128" << std::flush;
|
||||
auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
std::cerr << "sparge_blockmap_fwd: unsupported config (data_type=" << traits.data_type
|
||||
<< ", hdim_q=" << traits.hdim_q << ")" << std::endl;
|
||||
return -1.f;
|
||||
}
|
||||
93
example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp
Normal file
93
example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp
Normal file
@@ -0,0 +1,93 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp"
|
||||
#include "ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp"
|
||||
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
// ============================================================================
|
||||
// Args and traits for sparge block map GPU kernel
|
||||
// ============================================================================
|
||||
struct sparge_blockmap_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_k;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
|
||||
float simthreshd1;
|
||||
float cdfthreshd;
|
||||
float topk;
|
||||
float scale;
|
||||
|
||||
void* block_map_ptr;
|
||||
void* lut_ptr;
|
||||
void* valid_block_num_ptr;
|
||||
};
|
||||
|
||||
struct sparge_blockmap_traits
|
||||
{
|
||||
std::string data_type;
|
||||
int hdim_q;
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Create kernel args and grid dimensions
|
||||
// ============================================================================
|
||||
template <typename BlockMapKernel>
|
||||
auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = BlockMapKernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
args.nhead_q,
|
||||
args.nhead_q / args.nhead_k,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.nhead_stride_q,
|
||||
args.nhead_stride_k,
|
||||
args.batch_stride_q,
|
||||
args.batch_stride_k,
|
||||
args.simthreshd1,
|
||||
args.cdfthreshd,
|
||||
args.topk,
|
||||
args.scale,
|
||||
args.block_map_ptr,
|
||||
args.lut_ptr,
|
||||
args.valid_block_num_ptr);
|
||||
|
||||
dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q);
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Hand-written template instantiation dispatch
|
||||
// ============================================================================
|
||||
float sparge_blockmap_fwd(sparge_blockmap_traits traits,
|
||||
sparge_blockmap_args args,
|
||||
const ck_tile::stream_config& stream_config);
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "vsa_sparge_attention.h"
|
||||
#include "sparge_blockmap.h"
|
||||
#include "sparge_tool.hpp"
|
||||
|
||||
// ============================================================================
|
||||
@@ -198,53 +199,37 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::HostTensor<T> output_host =
|
||||
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
|
||||
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
|
||||
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
|
||||
|
||||
std::cout << "\nInitializing tensors..." << std::endl;
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
|
||||
|
||||
// Build block map using Sparge tool
|
||||
std::cout << "Building Sparge block map..." << std::endl;
|
||||
sparge::SpargeParams p;
|
||||
p.BLKQ = static_cast<int>(BLKQ);
|
||||
p.BLKK = static_cast<int>(BLKK);
|
||||
p.simthreshd1 = simthreshd1;
|
||||
p.cdfthreshd = cdfthreshd;
|
||||
p.topk = topk;
|
||||
p.i_perm = i_perm;
|
||||
|
||||
ck_tile::HostTensor<uint8_t> block_relation_onehot =
|
||||
sparge::build_block_map_meansim(q_host, k_host, p);
|
||||
|
||||
// Convert to VSA LUT (delta-encoded) + valid_block_num
|
||||
std::cout << "Converting block map to VSA LUT (delta)..." << std::endl;
|
||||
auto vsa_lut = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
|
||||
|
||||
// Print actual sparsity (based on one-hot)
|
||||
std::size_t total_blocks = 0;
|
||||
std::size_t active_blocks = 0;
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
|
||||
{
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
total_blocks++;
|
||||
if(block_relation_onehot(b, h, qb, kb) != 0)
|
||||
active_blocks++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
float actual_sparsity =
|
||||
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
|
||||
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
|
||||
<< total_blocks << " blocks active)" << std::endl;
|
||||
// ==================================================================
|
||||
// GPU: Build block map + VSA LUT in one kernel (always run)
|
||||
// ==================================================================
|
||||
std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl;
|
||||
ck_tile::HostTensor<uint8_t> block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
auto vsa_lut_gpu = sparge_blockmap_gpu<T>(q_host,
|
||||
k_host,
|
||||
block_map_gpu,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
i_perm,
|
||||
simthreshd1,
|
||||
cdfthreshd,
|
||||
topk,
|
||||
static_cast<int>(BLKQ),
|
||||
static_cast<int>(BLKK),
|
||||
0);
|
||||
|
||||
// ==================================================================
|
||||
// VSA sparse attention kernel (always run)
|
||||
// ==================================================================
|
||||
std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
|
||||
|
||||
try
|
||||
@@ -254,8 +239,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
vsa_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut.lut,
|
||||
vsa_lut.valid_block_num,
|
||||
vsa_lut_gpu.lut,
|
||||
vsa_lut_gpu.valid_block_num,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
@@ -276,8 +261,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
vsa_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut.lut,
|
||||
vsa_lut.valid_block_num,
|
||||
vsa_lut_gpu.lut,
|
||||
vsa_lut_gpu.valid_block_num,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
@@ -301,8 +286,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
vsa_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut.lut,
|
||||
vsa_lut.valid_block_num,
|
||||
vsa_lut_gpu.lut,
|
||||
vsa_lut_gpu.valid_block_num,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
@@ -332,17 +317,168 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// Sparsity statistics (always run, pure CPU read of HostTensor)
|
||||
// ==================================================================
|
||||
std::size_t total_blocks = 0;
|
||||
std::size_t active_blocks = 0;
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
|
||||
{
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
total_blocks++;
|
||||
if(block_map_gpu(b, h, qb, kb) != 0)
|
||||
active_blocks++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
float actual_sparsity =
|
||||
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
|
||||
std::cout << "\n Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
|
||||
<< total_blocks << " blocks active)" << std::endl;
|
||||
|
||||
// ==================================================================
|
||||
// Validation (only when -v=1)
|
||||
// ==================================================================
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
std::cout << "\n--- Performing CPU validation ---" << std::endl;
|
||||
|
||||
// CPU golden: block map + VSA LUT
|
||||
std::cout << "Building Sparge block map (CPU golden)..." << std::endl;
|
||||
sparge::SpargeParams p;
|
||||
p.BLKQ = static_cast<int>(BLKQ);
|
||||
p.BLKK = static_cast<int>(BLKK);
|
||||
p.simthreshd1 = simthreshd1;
|
||||
p.cdfthreshd = cdfthreshd;
|
||||
p.topk = topk;
|
||||
p.i_perm = i_perm;
|
||||
|
||||
ck_tile::HostTensor<uint8_t> block_relation_onehot =
|
||||
sparge::build_block_map_meansim(q_host, k_host, p);
|
||||
|
||||
std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl;
|
||||
auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
|
||||
|
||||
// Validate block map
|
||||
std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl;
|
||||
{
|
||||
std::size_t bmap_mismatches = 0;
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
|
||||
{
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
if(block_map_gpu(b, h, qb, kb) !=
|
||||
block_relation_onehot(b, h, qb, kb))
|
||||
{
|
||||
bmap_mismatches++;
|
||||
if(bmap_mismatches <= 10)
|
||||
{
|
||||
std::cout
|
||||
<< " block_map mismatch at [" << b << "," << h << ","
|
||||
<< qb << "," << kb
|
||||
<< "]: GPU="
|
||||
<< static_cast<int>(block_map_gpu(b, h, qb, kb))
|
||||
<< " CPU="
|
||||
<< static_cast<int>(
|
||||
block_relation_onehot(b, h, qb, kb))
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::cout << " Block map mismatches: " << bmap_mismatches << " / "
|
||||
<< (batch * nhead * num_q_blocks * num_k_blocks) << std::endl;
|
||||
if(bmap_mismatches > 0)
|
||||
{
|
||||
std::cout << ">>> GPU BLOCK MAP VALIDATION FAILED <<<" << std::endl;
|
||||
pass = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << ">>> GPU BLOCK MAP VALIDATION PASSED <<<" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// Validate VSA LUT
|
||||
std::cout << "\n--- Validating GPU VSA LUT vs CPU golden ---" << std::endl;
|
||||
{
|
||||
std::size_t lut_mismatches = 0;
|
||||
std::size_t valid_mismatches = 0;
|
||||
for(ck_tile::index_t b = 0; b < batch; ++b)
|
||||
{
|
||||
for(ck_tile::index_t h = 0; h < nhead; ++h)
|
||||
{
|
||||
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
|
||||
{
|
||||
if(vsa_lut_gpu.valid_block_num(b, h, qb) !=
|
||||
vsa_lut_cpu.valid_block_num(b, h, qb))
|
||||
{
|
||||
valid_mismatches++;
|
||||
if(valid_mismatches <= 5)
|
||||
{
|
||||
std::cout
|
||||
<< " valid_block_num mismatch at [" << b << "," << h
|
||||
<< "," << qb
|
||||
<< "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb)
|
||||
<< " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb)
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
if(vsa_lut_gpu.lut(b, h, qb, kb) !=
|
||||
vsa_lut_cpu.lut(b, h, qb, kb))
|
||||
{
|
||||
lut_mismatches++;
|
||||
if(lut_mismatches <= 10)
|
||||
{
|
||||
std::cout
|
||||
<< " LUT mismatch at [" << b << "," << h << "," << qb
|
||||
<< "," << kb
|
||||
<< "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb)
|
||||
<< " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb)
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
std::cout << " LUT mismatches: " << lut_mismatches << std::endl;
|
||||
std::cout << " valid_block_num mismatches: " << valid_mismatches << std::endl;
|
||||
if(lut_mismatches == 0 && valid_mismatches == 0)
|
||||
{
|
||||
std::cout << ">>> GPU VSA LUT VALIDATION PASSED <<<" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << ">>> GPU VSA LUT VALIDATION FAILED <<<" << std::endl;
|
||||
pass = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Validate attention output
|
||||
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
std::cout << "Computing reference output..." << std::endl;
|
||||
std::cout << "\nComputing reference attention output..." << std::endl;
|
||||
auto q_ref = to_bhsd(q_host, i_perm);
|
||||
auto k_ref = to_bhsd(k_host, i_perm);
|
||||
auto v_ref = to_bhsd(v_host, i_perm);
|
||||
|
||||
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
|
||||
ck_tile::reference_blocked_attention<T, uint8_t>(
|
||||
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
|
||||
|
||||
@@ -374,7 +510,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "\nValidation results:" << std::endl;
|
||||
std::cout << "\nAttention validation results:" << std::endl;
|
||||
std::cout << " Max absolute difference: " << max_diff << std::endl;
|
||||
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
|
||||
std::cout << " Number of mismatches: " << num_errors << " / "
|
||||
|
||||
@@ -0,0 +1,195 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Pipeline_>
|
||||
struct SpargeBlockMapKernel
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
|
||||
static constexpr index_t kBlockSize = Pipeline::kBlockSize;
|
||||
static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
|
||||
|
||||
using QDataType = typename Pipeline::QDataType;
|
||||
using KDataType = typename Pipeline::KDataType;
|
||||
|
||||
static constexpr index_t kM0 = Pipeline::kM0;
|
||||
static constexpr index_t kN0 = Pipeline::kN0;
|
||||
static constexpr index_t D = Pipeline::D;
|
||||
|
||||
static constexpr index_t kAlignment = 16 / sizeof(QDataType);
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
|
||||
index_t seqlen_q;
|
||||
index_t seqlen_k;
|
||||
index_t hdim_q;
|
||||
|
||||
index_t nhead_q;
|
||||
index_t nhead_ratio_qk;
|
||||
|
||||
index_t stride_q;
|
||||
index_t stride_k;
|
||||
index_t nhead_stride_q;
|
||||
index_t nhead_stride_k;
|
||||
index_t batch_stride_q;
|
||||
index_t batch_stride_k;
|
||||
|
||||
float simthreshd1;
|
||||
float cdfthreshd;
|
||||
float topk;
|
||||
float scale;
|
||||
|
||||
void* block_map_ptr;
|
||||
void* lut_ptr;
|
||||
void* valid_block_num_ptr;
|
||||
|
||||
index_t N_k;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
index_t seqlen_q,
|
||||
index_t seqlen_k,
|
||||
index_t hdim_q,
|
||||
index_t nhead_q,
|
||||
index_t nhead_ratio_qk,
|
||||
index_t stride_q,
|
||||
index_t stride_k,
|
||||
index_t nhead_stride_q,
|
||||
index_t nhead_stride_k,
|
||||
index_t batch_stride_q,
|
||||
index_t batch_stride_k,
|
||||
float simthreshd1,
|
||||
float cdfthreshd,
|
||||
float topk,
|
||||
float scale,
|
||||
void* block_map_ptr,
|
||||
void* lut_ptr,
|
||||
void* valid_block_num_ptr)
|
||||
{
|
||||
const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
|
||||
return Kargs{q_ptr,
|
||||
k_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
nhead_q,
|
||||
nhead_ratio_qk,
|
||||
stride_q,
|
||||
stride_k,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
simthreshd1,
|
||||
cdfthreshd,
|
||||
topk,
|
||||
scale,
|
||||
block_map_ptr,
|
||||
lut_ptr,
|
||||
valid_block_num_ptr,
|
||||
N_k};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q)
|
||||
{
|
||||
const index_t Q_blk = integer_divide_ceil(seqlen_q, kM0);
|
||||
return dim3(Q_blk, nhead_q, batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
const index_t qb = static_cast<index_t>(blockIdx.x);
|
||||
const index_t hq = static_cast<index_t>(blockIdx.y);
|
||||
const index_t b = static_cast<index_t>(blockIdx.z);
|
||||
|
||||
const index_t hk = hq / kargs.nhead_ratio_qk;
|
||||
|
||||
// Q pointer for this (batch, head, q_block)
|
||||
const auto* q_base = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
b * kargs.batch_stride_q + hq * kargs.nhead_stride_q +
|
||||
qb * kM0 * kargs.stride_q;
|
||||
|
||||
// K pointer for this (batch, head_k)
|
||||
const auto* k_base = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
b * kargs.batch_stride_k + hk * kargs.nhead_stride_k;
|
||||
|
||||
// Q DRAM view with OOB padding
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_base,
|
||||
make_tuple(kargs.seqlen_q - qb * kM0, D),
|
||||
make_tuple(kargs.stride_q, 1),
|
||||
number<kAlignment>{},
|
||||
number<1>{});
|
||||
const auto q_dram = pad_tensor_view(
|
||||
q_dram_naive, make_tuple(number<kM0>{}, number<D>{}), sequence<true, false>{});
|
||||
|
||||
auto q_window = make_tile_window(q_dram,
|
||||
make_tuple(number<kM0>{}, number<D>{}),
|
||||
{0, 0},
|
||||
Pipeline::MakeQBlockDistribution());
|
||||
|
||||
// K DRAM view with OOB padding
|
||||
const auto k_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global>(k_base,
|
||||
make_tuple(kargs.seqlen_k, D),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
number<kAlignment>{},
|
||||
number<1>{});
|
||||
const auto k_dram = pad_tensor_view(
|
||||
k_dram_naive, make_tuple(number<kN0>{}, number<D>{}), sequence<true, false>{});
|
||||
|
||||
auto k_window = make_tile_window(k_dram,
|
||||
make_tuple(number<kN0>{}, number<D>{}),
|
||||
{0, 0},
|
||||
Pipeline::MakeKBlockDistribution());
|
||||
|
||||
// Output pointers for this (batch, head, q_block)
|
||||
const index_t N_k = kargs.N_k;
|
||||
const index_t bmap_offset =
|
||||
(b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) * N_k + qb * N_k;
|
||||
auto* bmap_ptr = reinterpret_cast<uint8_t*>(kargs.block_map_ptr) + bmap_offset;
|
||||
|
||||
int32_t* lut_out = nullptr;
|
||||
int32_t* valid_out = nullptr;
|
||||
if(kargs.lut_ptr != nullptr)
|
||||
{
|
||||
lut_out = reinterpret_cast<int32_t*>(kargs.lut_ptr) + bmap_offset;
|
||||
const index_t valid_offset =
|
||||
(b * kargs.nhead_q + hq) * integer_divide_ceil(kargs.seqlen_q, kM0) + qb;
|
||||
valid_out = reinterpret_cast<int32_t*>(kargs.valid_block_num_ptr) + valid_offset;
|
||||
}
|
||||
|
||||
// Shared memory
|
||||
__shared__ char smem[Pipeline::GetSmemSize()];
|
||||
|
||||
Pipeline{}(q_window,
|
||||
k_window,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
qb,
|
||||
N_k,
|
||||
kargs.nhead_ratio_qk,
|
||||
kargs.simthreshd1,
|
||||
kargs.cdfthreshd,
|
||||
kargs.topk,
|
||||
kargs.scale,
|
||||
bmap_ptr,
|
||||
lut_out,
|
||||
valid_out,
|
||||
static_cast<void*>(smem));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,521 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// Device-side pipeline that builds a sparse-attention block map for one
/// (batch, head, q-block) tile: it decides which K blocks a Q block must
/// attend to, writing a 0/1 byte map and (optionally) a delta-encoded LUT
/// of selected K-block indices plus their count.
///
/// Algorithm (as implemented below):
///   Phase 1: Q block statistics — per-token L2 norms, column mean, and a
///            self-similarity score; if the Q block is NOT self-similar,
///            every K block is switched on and we exit early.
///   Phase 2: per-K-block loop — mean-pool each K block, score it by
///            dot(mean_q, mean_k) * scale; K blocks that are themselves not
///            self-similar are force-enabled with score -inf.
///   Phase 3: softmax over the scores, then iterative block-wide argmax
///            selection until either a top-k count or a CDF threshold
///            (cdfthreshd) is reached.
template <typename Problem_>
struct SpargeBlockMapPipeline
{
    using Problem = remove_cvref_t<Problem_>;
    using QDataType = remove_cvref_t<typename Problem::QDataType>;
    using KDataType = remove_cvref_t<typename Problem::KDataType>;
    using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;
    static constexpr index_t kM0 = BlockFmhaShape::kM0;       // Q-block rows
    static constexpr index_t kN0 = BlockFmhaShape::kN0;       // K-block rows
    static constexpr index_t D = BlockFmhaShape::kQKHeaddim;  // head dim
    static constexpr index_t NumWarps = BlockFmhaShape::NumWarps;
    static constexpr index_t WarpSize = get_warp_size();

    // Lane layout: each thread loads 16 bytes (KPerThread elements) of the
    // head dimension; KThreads consecutive lanes cover one row of D, and the
    // remaining lanes of a warp (SeqThreadPerWarp) cover distinct rows.
    // MPerThread/NPerThread are the row-repeats needed to cover kM0/kN0.
    static constexpr index_t KPerThread = 16 / sizeof(QDataType);
    static constexpr index_t KThreads = D / KPerThread;
    static constexpr index_t SeqThreadPerWarp = WarpSize / KThreads;
    static constexpr index_t MPerThread = kM0 / (SeqThreadPerWarp * NumWarps);
    static constexpr index_t NPerThread = kN0 / (SeqThreadPerWarp * NumWarps);

    static constexpr index_t kBlockPerCu = 1;
    // Capacity of the LDS score/bmap arrays.
    // NOTE(review): N_k is never checked against kMaxKBlocks in operator();
    // a caller passing N_k > 1024 would overflow LDS — confirm the host side
    // guarantees this bound.
    static constexpr index_t kMaxKBlocks = 1024;

    // LDS layout (non-overlapping, all used simultaneously in Phase 2):
    // [0 .. kReduceBytes)   cross-warp reduction scratch (NumWarps x D floats)
    // [kScoreOffset ..)     scores[N_k]            (float)
    // [kBmapOffset ..)      block_map[N_k]         (uint8_t)
    // [kSmallOffset ..)     Phase 3 argmax scratch (2*NumWarps floats:
    //                       values in [0..NumWarps), bit-cast indices after)
    static constexpr index_t kReduceBytes = NumWarps * D * sizeof(float);
    static constexpr index_t kScoreOffset = kReduceBytes;
    static constexpr index_t kBmapOffset = kScoreOffset + kMaxKBlocks * sizeof(float);
    static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t);

    /// Total LDS bytes required by operator().
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return kSmallOffset + 2 * NumWarps * sizeof(float);
    }

    /// Tile distribution for loading the Q block: rows split across
    /// (MPerThread repeats x NumWarps x SeqThreadPerWarp lanes), head dim
    /// split across (KThreads lanes x KPerThread vector elements).
    CK_TILE_HOST_DEVICE static constexpr auto MakeQBlockDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<MPerThread, NumWarps, SeqThreadPerWarp>,
                                             sequence<KThreads, KPerThread>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    /// Same layout as MakeQBlockDistribution but sized for a kN0-row K block.
    CK_TILE_HOST_DEVICE static constexpr auto MakeKBlockDistribution()
    {
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<NPerThread, NumWarps, SeqThreadPerWarp>,
                                             sequence<KThreads, KPerThread>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<1>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 1>>{});
    }

    // Extract tile data into a local float array via static_for (compile-time indices).
    template <index_t BufSize, typename Tile>
    CK_TILE_DEVICE static void tile_to_float(const Tile& tile, float (&out)[BufSize])
    {
        static_assert(Tile::get_thread_buffer_size() == BufSize);
        const auto& buf = tile.get_thread_buffer();
        static_for<0, BufSize, 1>{}([&](auto i) { out[i.value] = type_convert<float>(buf[i]); });
    }

    // Column-wise (dim=0) sum: accumulate SeqPerThread rows into KPerThread partial sums,
    // then xor-shuffle across m_idx within warp. After this, every lane holds
    // the warp-wide column sums for its KPerThread slice of the head dim;
    // cross-warp combination is done separately in column_reduce_cross_warp.
    template <index_t SeqPerThread>
    CK_TILE_DEVICE static void column_reduce_thread_and_warp(const float* __restrict__ data,
                                                            float (&col_acc)[KPerThread])
    {
        for(index_t k = 0; k < KPerThread; ++k)
            col_acc[k] = 0.f;

        // Per-thread accumulation over the rows this thread owns.
        for(index_t m = 0; m < SeqPerThread; ++m)
            for(index_t k = 0; k < KPerThread; ++k)
                col_acc[k] += data[m * KPerThread + k];

        // Butterfly across lanes with the same k_idx (strides start at
        // KThreads so lanes covering the same head-dim slice are combined).
        for(index_t stride = KThreads; stride < WarpSize; stride *= 2)
            for(index_t k = 0; k < KPerThread; ++k)
                col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride);
    }

    // Cross-warp LDS reduction for column sums. On entry each warp's m_idx==0
    // lanes hold that warp's column sums; on exit every thread holds the
    // block-wide sums for its KPerThread slice.
    CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread],
                                                        float* __restrict__ smem_reduce)
    {
        const index_t tid = static_cast<index_t>(threadIdx.x);
        const index_t warp_id = tid / WarpSize;
        const index_t lane_id = tid % WarpSize;
        const index_t k_idx = lane_id % KThreads;
        const index_t m_idx = lane_id / KThreads;

        // One representative row of lanes per warp publishes its sums.
        if(m_idx == 0)
            for(index_t k = 0; k < KPerThread; ++k)
                smem_reduce[warp_id * D + k_idx * KPerThread + k] = col_acc[k];
        __syncthreads();

        // Every thread re-reads and totals all warps' contributions.
        for(index_t k = 0; k < KPerThread; ++k)
            col_acc[k] = 0.f;
        for(index_t w = 0; w < NumWarps; ++w)
            for(index_t k = 0; k < KPerThread; ++k)
                col_acc[k] += smem_reduce[w * D + k_idx * KPerThread + k];
        // Second barrier lets smem_reduce be reused immediately by the caller.
        __syncthreads();
    }

    // Compute ||v||^2 per row: sum along KPerThread then xor-shuffle across k_idx.
    // Rows past actual_seq (OOB padding) are zeroed so they can't pollute
    // later normalised sums.
    template <index_t SeqPerThread>
    CK_TILE_DEVICE static void row_reduce_sq_norm(const float* __restrict__ data,
                                                 float (&row_norms)[SeqPerThread],
                                                 index_t actual_seq)
    {
        const index_t tid = static_cast<index_t>(threadIdx.x);
        const index_t warp_id = tid / WarpSize;
        const index_t m_idx = (tid % WarpSize) / KThreads;

        for(index_t m = 0; m < SeqPerThread; ++m)
        {
            float sq = 0.f;
            for(index_t k = 0; k < KPerThread; ++k)
            {
                float v = data[m * KPerThread + k];
                sq += v * v;
            }
            // Combine the KThreads lanes that share this row.
            for(index_t stride = 1; stride < KThreads; stride *= 2)
                sq += warp_shuffle(sq, __lane_id() ^ stride);

            // Global row index of this (repeat, warp, lane) combination.
            index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx;
            row_norms[m] = (gsq < actual_seq) ? sq : 0.f;
        }
    }

    // Column reduce of normalised rows: sum_hat[d] = sum_i data[i,d] / ||data[i,:]||.
    // row_norms must hold squared norms from row_reduce_sq_norm; OOB rows and
    // zero-norm rows contribute nothing.
    template <index_t SeqPerThread>
    CK_TILE_DEVICE static void column_reduce_normalised(const float* __restrict__ data,
                                                       const float* __restrict__ row_norms,
                                                       float (&col_acc)[KPerThread],
                                                       index_t actual_seq)
    {
        const index_t tid = static_cast<index_t>(threadIdx.x);
        const index_t warp_id = tid / WarpSize;
        const index_t m_idx = (tid % WarpSize) / KThreads;

        for(index_t k = 0; k < KPerThread; ++k)
            col_acc[k] = 0.f;

        for(index_t m = 0; m < SeqPerThread; ++m)
        {
            float inv_norm = (row_norms[m] > 0.f) ? (1.0f / __builtin_sqrtf(row_norms[m])) : 0.f;
            index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx;
            if(gsq < actual_seq)
                for(index_t k = 0; k < KPerThread; ++k)
                    col_acc[k] += data[m * KPerThread + k] * inv_norm;
        }

        // Same warp-level butterfly as column_reduce_thread_and_warp.
        for(index_t stride = KThreads; stride < WarpSize; stride *= 2)
            for(index_t k = 0; k < KPerThread; ++k)
                col_acc[k] += warp_shuffle(col_acc[k], __lane_id() ^ stride);
    }

    // Scalar reduce across k_idx lanes (within warp): combines the KThreads
    // lanes that cover one row's head dimension.
    CK_TILE_DEVICE static float reduce_across_k(float v)
    {
        for(index_t stride = 1; stride < KThreads; stride *= 2)
            v += warp_shuffle(v, __lane_id() ^ stride);
        return v;
    }

    // Full-block scalar reduce (warp xor + cross-warp LDS).
    // NOTE(review): the tid==0 aggregation over NumWarps entries assumes
    // blockDim.x == NumWarps * WarpSize — confirm the launch configuration.
    CK_TILE_DEVICE static float block_reduce_sum(float v, float* smem_small)
    {
        const index_t tid = static_cast<index_t>(threadIdx.x);
        const index_t warp_id = tid / WarpSize;
        const index_t lane_id = tid % WarpSize;

        for(index_t stride = 1; stride < WarpSize; stride *= 2)
            v += warp_shuffle(v, __lane_id() ^ stride);
        if(lane_id == 0)
            smem_small[warp_id] = v;
        __syncthreads();
        if(tid == 0)
        {
            float s = 0.f;
            for(index_t w = 0; w < NumWarps; ++w)
                s += smem_small[w];
            smem_small[0] = s;
        }
        __syncthreads();
        // All threads return the same block-wide total.
        return smem_small[0];
    }

    /// Block-wide max; same structure as block_reduce_sum.
    CK_TILE_DEVICE static float block_reduce_max(float v, float* smem_small)
    {
        const index_t tid = static_cast<index_t>(threadIdx.x);
        const index_t warp_id = tid / WarpSize;
        const index_t lane_id = tid % WarpSize;

        for(index_t stride = 1; stride < WarpSize; stride *= 2)
            v = max(v, warp_shuffle(v, __lane_id() ^ stride));
        if(lane_id == 0)
            smem_small[warp_id] = v;
        __syncthreads();
        if(tid == 0)
        {
            float s = smem_small[0];
            for(index_t w = 1; w < NumWarps; ++w)
                s = max(s, smem_small[w]);
            smem_small[0] = s;
        }
        __syncthreads();
        return smem_small[0];
    }

    // ======================================================================
    /// Run the block-map computation for one (batch, head, q-block).
    ///
    /// @param q_window_in / k_window_in  tile windows over the padded Q/K
    ///        DRAM views (K window is advanced locally by kN0 per K block)
    /// @param seqlen_q, seqlen_k  true (unpadded) sequence lengths
    /// @param qb   index of the Q block this workgroup handles
    /// @param N_k  number of K blocks (must fit kMaxKBlocks — see note above)
    /// @param simthreshd1  self-similarity threshold for Q and K blocks
    /// @param cdfthreshd   cumulative-probability stop for CDF selection
    /// @param topk  > 0: select ceil(topk * N_k) blocks; <= 0: CDF mode
    /// @param scale  score scale applied to dot(mean_q, mean_k)
    /// @param block_map_ptr  output: N_k bytes, 1 = attend to this K block
    /// @param lut_ptr  optional output: delta-encoded selected block indices
    ///        (each entry is the gap from the previous selected index)
    /// @param valid_block_num_ptr  optional output: count of selected blocks
    ///        (only written when lut_ptr != nullptr)
    /// @param smem_ptr  LDS scratch of at least GetSmemSize() bytes
    template <typename QWindowType, typename KWindowType>
    CK_TILE_DEVICE void operator()(const QWindowType& q_window_in,
                                   const KWindowType& k_window_in,
                                   index_t seqlen_q,
                                   index_t seqlen_k,
                                   index_t qb,
                                   index_t N_k,
                                   index_t /*nhead_ratio_qk*/,
                                   float simthreshd1,
                                   float cdfthreshd,
                                   float topk,
                                   float scale,
                                   uint8_t* block_map_ptr,
                                   int32_t* lut_ptr,
                                   int32_t* valid_block_num_ptr,
                                   void* smem_ptr) const
    {
        const index_t tid = static_cast<index_t>(threadIdx.x);

        // Carve the single LDS allocation into its four regions (layout
        // documented at the constants above).
        auto* smem_float = reinterpret_cast<float*>(smem_ptr);
        auto* smem_scores =
            reinterpret_cast<float*>(reinterpret_cast<char*>(smem_ptr) + kScoreOffset);
        auto* smem_bmap =
            reinterpret_cast<uint8_t*>(reinterpret_cast<char*>(smem_ptr) + kBmapOffset);
        auto* smem_small =
            reinterpret_cast<float*>(reinterpret_cast<char*>(smem_ptr) + kSmallOffset);

        // Effective (unpadded) row count of this Q block.
        const index_t bs_q = min(static_cast<index_t>(kM0), seqlen_q - qb * kM0);
        const float inv_bs_q = (bs_q > 0) ? (1.0f / static_cast<float>(bs_q)) : 0.f;

        // ==================================================================
        // Phase 1: Q Block Statistics
        // ==================================================================
        auto q_tile = load_tile(q_window_in);

        float q_data[MPerThread * KPerThread];
        tile_to_float<MPerThread * KPerThread>(q_tile, q_data);

        // 1a. L2 norm per token (squared; OOB rows zeroed)
        float psq[MPerThread];
        row_reduce_sq_norm<MPerThread>(q_data, psq, bs_q);

        // 1b. Column sum -> mean over the valid rows
        float pooled_q_mean[KPerThread];
        column_reduce_thread_and_warp<MPerThread>(q_data, pooled_q_mean);
        column_reduce_cross_warp(pooled_q_mean, smem_float);
        for(index_t k = 0; k < KPerThread; ++k)
            pooled_q_mean[k] *= inv_bs_q;

        // 1c. Normalised sum_hat: column sum of unit-norm rows
        float sum_hat[KPerThread];
        column_reduce_normalised<MPerThread>(q_data, psq, sum_hat, bs_q);
        column_reduce_cross_warp(sum_hat, smem_float);

        // 1d. sim_q = ||sum_hat||^2 / bs_q^2 compared against simthreshd1
        //     (||sum_hat||^2 == bs_q^2 when all rows point the same way)
        float sh_sq = 0.f;
        for(index_t k = 0; k < KPerThread; ++k)
            sh_sq += sum_hat[k] * sum_hat[k];
        sh_sq = reduce_across_k(sh_sq);
        const float denom_q = static_cast<float>(bs_q) * static_cast<float>(bs_q);
        const bool sim_q = (denom_q > 0.f) && ((sh_sq / denom_q) > simthreshd1);

        // Not similar → the mean-pooled approximation is unreliable:
        // force all K blocks ON, early exit
        if(!sim_q)
        {
            for(index_t i = tid; i < N_k; i += kBlockSize)
                block_map_ptr[i] = 1;

            if(lut_ptr != nullptr && tid == 0)
            {
                // Delta-encode every block index (all N_k blocks selected).
                int32_t valid = 0, prev = 0;
                for(index_t kb = 0; kb < N_k; ++kb)
                {
                    lut_ptr[valid] = static_cast<int32_t>(kb) - prev;
                    prev = static_cast<int32_t>(kb);
                    ++valid;
                }
                for(index_t i = valid; i < N_k; ++i)
                    lut_ptr[i] = 0;
                *valid_block_num_ptr = valid;
            }
            return;
        }

        // ==================================================================
        // Phase 2: K Block Loop
        // ==================================================================
        for(index_t i = tid; i < N_k; i += kBlockSize)
            smem_bmap[i] = 0;
        __syncthreads();

        // Local copy so we can advance the window block by block.
        auto k_window = k_window_in;

        for(index_t kb = 0; kb < N_k; ++kb)
        {
            const index_t bs_k = min(static_cast<index_t>(kN0), seqlen_k - kb * kN0);
            const float inv_bs_k = (bs_k > 0) ? (1.0f / static_cast<float>(bs_k)) : 0.f;

            auto k_tile = load_tile(k_window);

            float k_data[NPerThread * KPerThread];
            tile_to_float<NPerThread * KPerThread>(k_tile, k_data);

            // K mean (same column reduction as Phase 1b)
            float pooled_k_mean[KPerThread];
            column_reduce_thread_and_warp<NPerThread>(k_data, pooled_k_mean);
            column_reduce_cross_warp(pooled_k_mean, smem_float);
            for(index_t k = 0; k < KPerThread; ++k)
                pooled_k_mean[k] *= inv_bs_k;

            // dot(pooled_q_mean, pooled_k_mean) — raw attention proxy score
            float dot = 0.f;
            for(index_t k = 0; k < KPerThread; ++k)
                dot += pooled_q_mean[k] * pooled_k_mean[k];
            dot = reduce_across_k(dot);

            // K L2 norms + normalised sum_hat (self-similarity of this K block)
            float k_psq[NPerThread];
            row_reduce_sq_norm<NPerThread>(k_data, k_psq, bs_k);

            float k_sum_hat[KPerThread];
            column_reduce_normalised<NPerThread>(k_data, k_psq, k_sum_hat, bs_k);
            column_reduce_cross_warp(k_sum_hat, smem_float);

            // sim_k
            float ksh_sq = 0.f;
            for(index_t k = 0; k < KPerThread; ++k)
                ksh_sq += k_sum_hat[k] * k_sum_hat[k];
            ksh_sq = reduce_across_k(ksh_sq);
            const float denom_k = static_cast<float>(bs_k) * static_cast<float>(bs_k);
            const bool sim_k = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1);

            if(tid == 0)
            {
                if(!sim_k)
                {
                    // Non-similar K block: always attend; -inf keeps it out
                    // of the softmax/selection budget in Phase 3.
                    smem_bmap[kb] = 1;
                    smem_scores[kb] = -numeric<float>::infinity();
                }
                else
                {
                    smem_scores[kb] = dot * scale;
                }
            }
            __syncthreads();

            // Advance to the next K block along the sequence dimension.
            move_tile_window(k_window, {kN0, 0});
        }

        // ==================================================================
        // Phase 3: Softmax + Selection
        // ==================================================================

        // max (for a numerically stable softmax)
        float lmax = -numeric<float>::infinity();
        for(index_t i = tid; i < N_k; i += kBlockSize)
            lmax = max(lmax, smem_scores[i]);
        const float max_score = block_reduce_max(lmax, smem_small);

        // exp + sum (-inf entries, i.e. force-enabled blocks, contribute 0)
        float lsum = 0.f;
        for(index_t i = tid; i < N_k; i += kBlockSize)
        {
            float e = (smem_scores[i] > -numeric<float>::infinity())
                          ? __builtin_expf(smem_scores[i] - max_score)
                          : 0.f;
            smem_scores[i] = e;
            lsum += e;
        }
        const float sum_exp = block_reduce_sum(lsum, smem_small);

        // normalise into probabilities
        const float inv_sum = (sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
        for(index_t i = tid; i < N_k; i += kBlockSize)
            smem_scores[i] *= inv_sum;
        __syncthreads();

        // Selection: iterative argmax. Budget is either a fixed top-k count
        // (topk > 0) or, in CDF mode, up to N_k rounds bounded by cdfthreshd.
        index_t num_to_select =
            (topk > 0.f)
                ? max(static_cast<index_t>(1), static_cast<index_t>(topk * static_cast<float>(N_k)))
                : N_k;

        float cumulative_prob = 0.f;
        for(index_t round = 0; round < num_to_select; ++round)
        {
            // thread-local argmax over this thread's strided slice
            float best_val = -1.f;
            index_t best_idx = 0;
            for(index_t i = tid; i < N_k; i += kBlockSize)
            {
                if(smem_scores[i] > best_val || (smem_scores[i] == best_val && i < best_idx))
                {
                    best_val = smem_scores[i];
                    best_idx = i;
                }
            }

            // warp argmax (ties broken toward the lower index)
            for(index_t stride = 1; stride < WarpSize; stride *= 2)
            {
                float rv = warp_shuffle(best_val, __lane_id() ^ stride);
                index_t ri = warp_shuffle(best_idx, __lane_id() ^ stride);
                if(rv > best_val || (rv == best_val && ri < best_idx))
                {
                    best_val = rv;
                    best_idx = ri;
                }
            }

            // cross-warp argmax via LDS; indices travel through the float
            // scratch via bit_cast so one buffer serves both.
            const index_t lane_id = tid % WarpSize;
            const index_t warp_id = tid / WarpSize;
            if(lane_id == 0)
            {
                smem_small[warp_id] = best_val;
                smem_small[NumWarps + warp_id] = bit_cast<float>(static_cast<int32_t>(best_idx));
            }
            __syncthreads();

            if(tid == 0)
            {
                float bv = smem_small[0];
                index_t bi = bit_cast<int32_t>(smem_small[NumWarps]);
                for(index_t w = 1; w < NumWarps; ++w)
                {
                    float wv = smem_small[w];
                    index_t wi = bit_cast<int32_t>(smem_small[NumWarps + w]);
                    if(wv > bv || (wv == bv && wi < bi))
                    {
                        bv = wv;
                        bi = wi;
                    }
                }
                // Broadcast winner through slots 0/1.
                smem_small[0] = bv;
                smem_small[1] = bit_cast<float>(static_cast<int32_t>(bi));
            }
            __syncthreads();

            float g_val = smem_small[0];
            index_t g_idx = bit_cast<int32_t>(smem_small[1]);

            // All remaining scores exhausted (selected blocks are set to -1,
            // force-enabled blocks normalised to 0) → nothing left to pick.
            if(g_val <= 0.f)
                break;

            if(tid == 0)
            {
                smem_bmap[g_idx] = 1;
                smem_scores[g_idx] = -1.f; // remove from future rounds
            }
            __syncthreads();

            if(topk > 0.f)
            {
                if(round + 1 >= num_to_select)
                    break;
            }
            else
            {
                // CDF mode: stop once enough probability mass is covered.
                cumulative_prob += g_val;
                if(cumulative_prob >= cdfthreshd)
                    break;
            }
        }

        // ==================================================================
        // Write outputs to global memory
        // ==================================================================
        for(index_t i = tid; i < N_k; i += kBlockSize)
            block_map_ptr[i] = smem_bmap[i];

        if(lut_ptr != nullptr && tid == 0)
        {
            // Delta-encode the selected block indices; zero-pad the tail.
            int32_t valid = 0, prev = 0;
            for(index_t kb = 0; kb < N_k; ++kb)
            {
                if(smem_bmap[kb] != 0)
                {
                    lut_ptr[valid] = static_cast<int32_t>(kb) - prev;
                    prev = static_cast<int32_t>(kb);
                    ++valid;
                }
            }
            for(index_t i = valid; i < N_k; ++i)
                lut_ptr[i] = 0;
            *valid_block_num_ptr = valid;
        }
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user