Remove extra host-side operations from the Sparge VSA example.

Gino Lu
2026-04-14 10:11:00 -04:00
parent d1d457b82a
commit c7e6e4f616
6 changed files with 163 additions and 540 deletions
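
Reviewer note: the two deleted wrappers (sparge_blockmap_gpu and vsa_sparge_attention, diffs below) each allocated their own DeviceMem and paid a full HtoD/DtoH round trip per call, so the LUT produced by the block-map kernel bounced through the host before the attention kernel consumed it. The demo now owns the device buffers and chains the two kernels on device. A minimal sketch of the new flow, assuming the traits/args structs are pre-filled as in the demo diff (run_device_pipeline is a hypothetical name, not part of this commit):

// Hypothetical sketch only: traits/args are assumed pre-filled as in the demo
// below, except for the device pointers wired here. Error handling elided.
#include <cstddef>
#include "ck_tile/host/device_memory.hpp"
#include "sparge_blockmap_trek.hpp" // declares sparge_blockmap_traits/args (path as in this diff)
#include "fmha_fwd_trek.hpp"        // declares fmha_vsa_fwd_traits/args (path as in this diff)

void run_device_pipeline(const void* q, const void* k, const void* v, void* o,
                         std::size_t q_bytes, std::size_t k_bytes,
                         std::size_t v_bytes, std::size_t o_bytes,
                         std::size_t lut_bytes, std::size_t valid_bytes,
                         sparge_blockmap_traits bm_traits, sparge_blockmap_args bm_args,
                         fmha_vsa_fwd_traits at_traits, fmha_vsa_fwd_args at_args)
{
    ck_tile::DeviceMem q_buf(q_bytes), k_buf(k_bytes), v_buf(v_bytes), o_buf(o_bytes);
    ck_tile::DeviceMem lut_buf(lut_bytes), valid_buf(valid_bytes);
    q_buf.ToDevice(q); k_buf.ToDevice(k); v_buf.ToDevice(v);   // HtoD exactly once
    lut_buf.SetZero(); valid_buf.SetZero();
    bm_args.q_ptr = q_buf.GetDeviceBuffer();                   // wire device pointers
    bm_args.k_ptr = k_buf.GetDeviceBuffer();
    bm_args.block_map_ptr = nullptr;                           // optional in this sketch
    bm_args.lut_ptr = lut_buf.GetDeviceBuffer();
    bm_args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
    sparge_blockmap_fwd(bm_traits, bm_args, ck_tile::stream_config{}); // LUT stays on device
    at_args.q_ptr = q_buf.GetDeviceBuffer();
    at_args.k_ptr = k_buf.GetDeviceBuffer();
    at_args.v_ptr = v_buf.GetDeviceBuffer();
    at_args.lut_ptr = lut_buf.GetDeviceBuffer();               // no host round trip
    at_args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
    at_args.o_ptr = o_buf.GetDeviceBuffer();
    sparge_vsa_fwd(at_traits, at_args, ck_tile::stream_config{});
    o_buf.FromDevice(o, o_bytes);                              // DtoH exactly once
}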

View File

@@ -249,14 +249,12 @@ set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances")
add_library(${SPARGE_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
${SPARGE_VSA_GEN_BLOBS}
${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp
)
target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
)
set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp PROPERTIES LANGUAGE HIP)
set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE
@@ -273,7 +271,6 @@ set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances")
add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp
)
target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
${CMAKE_CURRENT_LIST_DIR}
@@ -281,7 +278,6 @@ target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
)
set_source_files_properties(
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp
PROPERTIES LANGUAGE HIP
)
set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})

View File

@@ -1,156 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "sparge_blockmap.h"
#include "sparge_blockmap_trek.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_memory.hpp"
#include <type_traits>
#include <cmath>
template <typename DataType_>
sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor<DataType_>& TQ,
const ck_tile::HostTensor<DataType_>& TK,
ck_tile::HostTensor<uint8_t>& block_map_out,
int batch,
int nhead_q,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
bool i_perm,
float simthreshd1,
float cdfthreshd,
float topk,
int blkq,
int blkk,
int log_level)
{
static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
std::is_same_v<DataType_, ck_tile::bf16_t>,
"sparge_blockmap_gpu supports fp16/bf16 only.");
std::string data_type = "fp16";
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
{
data_type = "bf16";
}
const ck_tile::index_t num_q_blocks = ck_tile::integer_divide_ceil(seqlen_q, blkq);
const ck_tile::index_t num_k_blocks = ck_tile::integer_divide_ceil(seqlen_k, blkk);
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
// Allocate device memory
ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
const std::size_t bmap_bytes =
static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(uint8_t);
const std::size_t lut_bytes =
static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(int32_t);
const std::size_t valid_bytes =
static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * sizeof(int32_t);
ck_tile::DeviceMem bmap_buf(bmap_bytes);
ck_tile::DeviceMem lut_buf(lut_bytes);
ck_tile::DeviceMem valid_buf(valid_bytes);
q_buf.ToDevice(TQ.data());
k_buf.ToDevice(TK.data());
bmap_buf.SetZero();
lut_buf.SetZero();
valid_buf.SetZero();
// Compute strides (assumes BHSD if i_perm, BSHD otherwise)
const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead_q * hdim_q;
const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q;
const ck_tile::index_t nhead_stride_q =
i_perm ? static_cast<ck_tile::index_t>(seqlen_q) * hdim_q : hdim_q;
const ck_tile::index_t nhead_stride_k =
i_perm ? static_cast<ck_tile::index_t>(seqlen_k) * hdim_q : hdim_q;
const ck_tile::index_t batch_stride_q =
static_cast<ck_tile::index_t>(nhead_q) * seqlen_q * hdim_q;
const ck_tile::index_t batch_stride_k =
static_cast<ck_tile::index_t>(nhead_k) * seqlen_k * hdim_q;
ck_tile::stream_config stream_config{nullptr, false, log_level, 0, 1, false};
sparge_blockmap_args args;
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
args.batch = batch;
args.seqlen_q = seqlen_q;
args.seqlen_k = seqlen_k;
args.hdim_q = hdim_q;
args.nhead_q = nhead_q;
args.nhead_k = nhead_k;
args.stride_q = stride_q;
args.stride_k = stride_k;
args.nhead_stride_q = nhead_stride_q;
args.nhead_stride_k = nhead_stride_k;
args.batch_stride_q = batch_stride_q;
args.batch_stride_k = batch_stride_k;
args.simthreshd1 = simthreshd1;
args.cdfthreshd = cdfthreshd;
args.topk = topk;
args.scale = scale;
args.block_map_ptr = bmap_buf.GetDeviceBuffer();
args.lut_ptr = lut_buf.GetDeviceBuffer();
args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
sparge_blockmap_traits traits;
traits.data_type = data_type;
traits.hdim_q = hdim_q;
sparge_blockmap_fwd(traits, args, stream_config);
// Copy results back to host
bmap_buf.FromDevice(block_map_out.data(), bmap_bytes);
sparge::VSALut vsa_lut{
ck_tile::HostTensor<int32_t>({batch, nhead_q, num_q_blocks, num_k_blocks}),
ck_tile::HostTensor<int32_t>({batch, nhead_q, num_q_blocks}),
};
lut_buf.FromDevice(vsa_lut.lut.data(), lut_bytes);
valid_buf.FromDevice(vsa_lut.valid_block_num.data(), valid_bytes);
return vsa_lut;
}
// Explicit template instantiations
template sparge::VSALut
sparge_blockmap_gpu<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
const ck_tile::HostTensor<ck_tile::half_t>&,
ck_tile::HostTensor<uint8_t>&,
int,
int,
int,
int,
int,
int,
bool,
float,
float,
float,
int,
int,
int);
template sparge::VSALut
sparge_blockmap_gpu<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
const ck_tile::HostTensor<ck_tile::bf16_t>&,
ck_tile::HostTensor<uint8_t>&,
int,
int,
int,
int,
int,
int,
bool,
float,
float,
float,
int,
int,
int);
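
The sizing logic in the deleted wrapper above survives unchanged in the demo: the block map holds one byte per (q-block, k-block) pair, the delta LUT one int32 per pair, and the valid counter one int32 per q-block, all per batch and query head. A self-contained arithmetic sketch with illustrative values (plain ints stand in for ck_tile::index_t; integer_divide_ceil is assumed to be ordinary ceiling division):

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
    const int seqlen_q = 1024, seqlen_k = 4096, blkq = 64, blkk = 64; // illustrative values
    const int batch = 2, nhead_q = 8;
    const int num_q_blocks = (seqlen_q + blkq - 1) / blkq;  // ceil(1024/64) = 16
    const int num_k_blocks = (seqlen_k + blkk - 1) / blkk;  // ceil(4096/64) = 64
    const std::size_t pairs =
        std::size_t(batch) * nhead_q * num_q_blocks * num_k_blocks;  // 2*8*16*64 = 16384
    const std::size_t bmap_bytes  = pairs * sizeof(uint8_t);         // 16384 bytes
    const std::size_t lut_bytes   = pairs * sizeof(int32_t);         // 65536 bytes
    const std::size_t valid_bytes =
        std::size_t(batch) * nhead_q * num_q_blocks * sizeof(int32_t); // 1024 bytes
    std::printf("bmap=%zu lut=%zu valid=%zu\n", bmap_bytes, lut_bytes, valid_bytes);
    return 0;
}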

View File

@@ -1,26 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "sparge_tool.hpp"
template <typename DataType_>
sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor<DataType_>& TQ,
const ck_tile::HostTensor<DataType_>& TK,
ck_tile::HostTensor<uint8_t>& block_map_out,
int batch,
int nhead_q,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
bool i_perm,
float simthreshd1,
float cdfthreshd,
float topk,
int blkq,
int blkk,
int log_level = 0);
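
For context, this is how the demo called the wrapper before this commit (the call is removed in the main-file diff below); every call re-uploaded Q/K and downloaded the block map, LUT, and valid counts:

// Usage as it appeared in the pre-commit demo; values come from the demo's
// argument parsing, and BLKQ/BLKK were its compile-time block sizes.
ck_tile::HostTensor<uint8_t> block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks});
auto vsa_lut_gpu = sparge_blockmap_gpu<T>(q_host, k_host, block_map_gpu,
                                          batch, nhead, nhead_k,
                                          seqlen_q, seqlen_k, hdim_q, i_perm,
                                          simthreshd1, cdfthreshd, topk,
                                          static_cast<int>(BLKQ),
                                          static_cast<int>(BLKK),
                                          /*log_level=*/0);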

View File

@@ -1,23 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention
// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention (all-in-device)
#include <iostream>
#include <vector>
#include <cmath>
#include <random>
#include <string>
#include <algorithm>
#include <numeric>
#include <chrono>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "vsa_sparge_attention.h"
#include "sparge_blockmap.h"
#include "sparge_blockmap_trek.hpp"
#include "fmha_fwd_trek.hpp"
#include "sparge_tool.hpp"
// ============================================================================
@@ -192,7 +186,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
<< ", topk=" << topk << ")" << std::endl;
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
// Create host tensors
// Create host tensors and fill with random data
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
@@ -206,119 +200,157 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
// ==================================================================
// GPU: Build block map + VSA LUT in one kernel (always run)
// Allocate device memory once, HtoD once
// ==================================================================
std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl;
ck_tile::HostTensor<uint8_t> block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks});
auto vsa_lut_gpu = sparge_blockmap_gpu<T>(q_host,
k_host,
block_map_gpu,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
i_perm,
simthreshd1,
cdfthreshd,
topk,
static_cast<int>(BLKQ),
static_cast<int>(BLKK),
0);
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(output_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
v_buf.ToDevice(v_host.data());
const std::size_t bmap_bytes =
static_cast<std::size_t>(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(uint8_t);
const std::size_t lut_bytes =
static_cast<std::size_t>(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(int32_t);
const std::size_t valid_bytes =
static_cast<std::size_t>(batch) * nhead * num_q_blocks * sizeof(int32_t);
ck_tile::DeviceMem bmap_buf(bmap_bytes);
ck_tile::DeviceMem lut_buf(lut_bytes);
ck_tile::DeviceMem valid_buf(valid_bytes);
bmap_buf.SetZero();
lut_buf.SetZero();
valid_buf.SetZero();
// ==================================================================
// VSA sparse attention kernel (always run)
// Common stride calculations
// ==================================================================
assert(nhead % nhead_k == 0);
const float scale_s = 1.0f / std::sqrt(static_cast<float>(hdim_q));
const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead * hdim_q;
const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q;
const ck_tile::index_t stride_v = i_perm ? hdim_v : nhead_k * hdim_v;
const ck_tile::index_t stride_o = o_perm ? hdim_v : nhead * hdim_v;
const ck_tile::index_t nhead_stride_q = i_perm ? seqlen_q * hdim_q : hdim_q;
const ck_tile::index_t nhead_stride_k = i_perm ? seqlen_k * hdim_q : hdim_q;
const ck_tile::index_t nhead_stride_v = i_perm ? seqlen_k * hdim_v : hdim_v;
const ck_tile::index_t nhead_stride_o = o_perm ? seqlen_q * hdim_v : hdim_v;
const ck_tile::index_t batch_stride_q = nhead * seqlen_q * hdim_q;
const ck_tile::index_t batch_stride_k = nhead_k * seqlen_k * hdim_q;
const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * seqlen_k;
const ck_tile::index_t batch_stride_o = nhead * seqlen_q * hdim_v;
std::string data_type = "fp16";
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
data_type = "bf16";
std::string msk_str = "0";
mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
// ==================================================================
// GPU: Build block map + VSA LUT (always run, device-only)
// ==================================================================
std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl;
{
sparge_blockmap_args args;
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
args.batch = batch;
args.seqlen_q = seqlen_q;
args.seqlen_k = seqlen_k;
args.hdim_q = hdim_q;
args.nhead_q = nhead;
args.nhead_k = nhead_k;
args.stride_q = stride_q;
args.stride_k = stride_k;
args.nhead_stride_q = nhead_stride_q;
args.nhead_stride_k = nhead_stride_k;
args.batch_stride_q = batch_stride_q;
args.batch_stride_k = batch_stride_k;
args.simthreshd1 = simthreshd1;
args.cdfthreshd = cdfthreshd;
args.topk = topk;
args.scale = scale_s;
args.block_map_ptr = bmap_buf.GetDeviceBuffer();
args.lut_ptr = lut_buf.GetDeviceBuffer();
args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
sparge_blockmap_traits traits;
traits.data_type = data_type;
traits.hdim_q = hdim_q;
sparge_blockmap_fwd(traits, args, ck_tile::stream_config{});
}
// ==================================================================
// VSA sparse attention kernel (always run, LUT stays on device)
// ==================================================================
std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
try
{
if(kname)
{
vsa_sparge_attention<T>(q_host,
k_host,
v_host,
vsa_lut_gpu.lut,
vsa_lut_gpu.valid_block_num,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
1);
}
fmha_vsa_fwd_args fmha_args;
fmha_args.q_ptr = q_buf.GetDeviceBuffer();
fmha_args.k_ptr = k_buf.GetDeviceBuffer();
fmha_args.v_ptr = v_buf.GetDeviceBuffer();
fmha_args.lut_ptr = lut_buf.GetDeviceBuffer();
fmha_args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
fmha_args.o_ptr = o_buf.GetDeviceBuffer();
fmha_args.batch = batch;
fmha_args.seqlen_q = seqlen_q;
fmha_args.seqlen_k = seqlen_k;
fmha_args.max_seqlen_q = seqlen_q;
fmha_args.hdim_q = hdim_q;
fmha_args.hdim_v = hdim_v;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead_k;
fmha_args.scale_s = scale_s;
fmha_args.stride_q = stride_q;
fmha_args.stride_k = stride_k;
fmha_args.stride_v = stride_v;
fmha_args.stride_o = stride_o;
fmha_args.nhead_stride_q = nhead_stride_q;
fmha_args.nhead_stride_k = nhead_stride_k;
fmha_args.nhead_stride_v = nhead_stride_v;
fmha_args.nhead_stride_o = nhead_stride_o;
fmha_args.batch_stride_q = batch_stride_q;
fmha_args.batch_stride_k = batch_stride_k;
fmha_args.batch_stride_v = batch_stride_v;
fmha_args.batch_stride_o = batch_stride_o;
fmha_args.window_size_left = mask.left;
fmha_args.window_size_right = mask.right;
fmha_args.mask_type = static_cast<ck_tile::index_t>(mask.type);
for(int i = 0; i < warmup; ++i)
{
vsa_sparge_attention<T>(q_host,
k_host,
v_host,
vsa_lut_gpu.lut,
vsa_lut_gpu.valid_block_num,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
fmha_vsa_fwd_traits fmha_traits;
fmha_traits.hdim_q = hdim_q;
fmha_traits.hdim_v = hdim_v;
fmha_traits.data_type = data_type;
fmha_traits.is_v_rowmajor = true;
fmha_traits.mask_type = mask.type;
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
auto start = std::chrono::high_resolution_clock::now();
ck_tile::stream_config stream_config{nullptr,
true,
/* log_level = */ kname ? 1 : 0,
warmup,
repeat,
false};
for(int i = 0; i < repeat; ++i)
{
vsa_sparge_attention<T>(q_host,
k_host,
v_host,
vsa_lut_gpu.lut,
vsa_lut_gpu.valid_block_num,
output_host,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
i_perm,
o_perm,
seqlen_q,
seqlen_k,
0);
}
float avg_time_ms = sparge_vsa_fwd(fmha_traits, fmha_args, stream_config);
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
auto end = std::chrono::high_resolution_clock::now();
double avg_time_ms =
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
<< std::endl;
std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
<< std::endl;
}
catch(const std::exception& e)
{
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
return false;
}
// DtoH: attention output (always needed)
o_buf.FromDevice(output_host.data(), output_host.get_element_space_size_in_bytes());
// DtoH: block_map (needed for sparsity stats and validation)
ck_tile::HostTensor<uint8_t> block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks});
bmap_buf.FromDevice(block_map_gpu.data(), bmap_bytes);
// ==================================================================
// Sparsity statistics (always run, pure CPU read of HostTensor)
// Sparsity statistics (pure CPU, reads block_map HostTensor)
// ==================================================================
std::size_t total_blocks = 0;
std::size_t active_blocks = 0;
@@ -366,6 +398,14 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl;
auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
// DtoH: LUT + valid_block_num (only for validation)
sparge::VSALut vsa_lut_gpu{
ck_tile::HostTensor<int32_t>({batch, nhead, num_q_blocks, num_k_blocks}),
ck_tile::HostTensor<int32_t>({batch, nhead, num_q_blocks}),
};
lut_buf.FromDevice(vsa_lut_gpu.lut.data(), lut_bytes);
valid_buf.FromDevice(vsa_lut_gpu.valid_block_num.data(), valid_bytes);
// Validate block map
std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl;
{
@@ -378,20 +418,16 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
{
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
{
if(block_map_gpu(b, h, qb, kb) !=
block_relation_onehot(b, h, qb, kb))
if(block_map_gpu(b, h, qb, kb) != block_relation_onehot(b, h, qb, kb))
{
bmap_mismatches++;
if(bmap_mismatches <= 10)
{
std::cout
<< " block_map mismatch at [" << b << "," << h << ","
<< qb << "," << kb
<< "]: GPU="
<< static_cast<int>(block_map_gpu(b, h, qb, kb))
<< " CPU="
<< static_cast<int>(
block_relation_onehot(b, h, qb, kb))
<< " block_map mismatch at [" << b << "," << h << "," << qb
<< "," << kb << "]: GPU="
<< static_cast<int>(block_map_gpu(b, h, qb, kb)) << " CPU="
<< static_cast<int>(block_relation_onehot(b, h, qb, kb))
<< std::endl;
}
}
@@ -429,28 +465,24 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
valid_mismatches++;
if(valid_mismatches <= 5)
{
std::cout
<< " valid_block_num mismatch at [" << b << "," << h
<< "," << qb
<< "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb)
<< " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb)
<< std::endl;
std::cout << " valid_block_num mismatch at [" << b << "," << h
<< "," << qb
<< "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb)
<< " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb)
<< std::endl;
}
}
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
{
if(vsa_lut_gpu.lut(b, h, qb, kb) !=
vsa_lut_cpu.lut(b, h, qb, kb))
if(vsa_lut_gpu.lut(b, h, qb, kb) != vsa_lut_cpu.lut(b, h, qb, kb))
{
lut_mismatches++;
if(lut_mismatches <= 10)
{
std::cout
<< " LUT mismatch at [" << b << "," << h << "," << qb
<< "," << kb
<< "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb)
<< " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb)
<< std::endl;
<< "," << kb << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb)
<< " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) << std::endl;
}
}
}
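
One convention in the demo's stride block above deserves a note: the code assumes packed tensors, laid out BHSD ([batch][head][seq][hdim]) when i_perm is set and BSHD ([batch][seq][head][hdim]) otherwise, which is why the batch stride comes out identical in both cases. A small sketch of that rule under the same packed-layout assumption (QStrides and q_strides are illustrative names, not demo code):

#include <cstdint>

struct QStrides { std::int64_t row, head, batch; };

QStrides q_strides(bool i_perm, std::int64_t nhead, std::int64_t seqlen, std::int64_t hdim)
{
    if(i_perm) // BHSD: consecutive rows of one head are contiguous
        return {hdim, seqlen * hdim, nhead * seqlen * hdim};
    // BSHD: all heads of one row are interleaved, so a row step skips nhead*hdim
    return {nhead * hdim, hdim, nhead * seqlen * hdim};
}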

View File

@@ -1,195 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "vsa_sparge_attention.h"
#include "fmha_fwd_trek.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_memory.hpp"
#include <type_traits>
template <typename DataType_>
ck_tile::HostTensor<DataType_>
vsa_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
const ck_tile::HostTensor<DataType_>& TK,
const ck_tile::HostTensor<DataType_>& TV,
const ck_tile::HostTensor<int32_t>& TKV_block_idx,
const ck_tile::HostTensor<int32_t>& TKV_blocks,
ck_tile::HostTensor<DataType_>& Y,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
bool i_perm,
bool o_perm,
int max_seqlen_q,
int max_seqlen_k,
int log_level)
{
static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
std::is_same_v<DataType_, ck_tile::bf16_t>,
"VSA sparse attention supports fp16/bf16 only.");
std::string data_type = "fp16";
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
{
data_type = "bf16";
}
if(max_seqlen_q == 0)
max_seqlen_q = seqlen_q;
if(max_seqlen_k == 0)
max_seqlen_k = seqlen_k;
bool is_v_rowmajor = true;
float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
std::string msk_str = "0";
mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
const ck_tile::index_t shape_seqlen_q = seqlen_q;
const ck_tile::index_t shape_seqlen_k = seqlen_k;
ck_tile::stream_config stream_config{nullptr,
false, // time_kernel
log_level,
0,
1,
false};
ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes());
ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
q_buf.ToDevice(TQ.data());
k_buf.ToDevice(TK.data());
v_buf.ToDevice(TV.data());
lut_buf.ToDevice(TKV_block_idx.data());
valid_block_num_buf.ToDevice(TKV_blocks.data());
const auto init_args = [&](auto& args) {
assert(nhead % nhead_k == 0);
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
}();
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
}();
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
args.v_ptr = v_buf.GetDeviceBuffer();
args.lut_ptr = lut_buf.GetDeviceBuffer();
args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer();
args.batch = batch;
args.seqlen_q = shape_seqlen_q;
args.hdim_q = hdim_q;
args.hdim_v = hdim_v;
args.nhead_q = nhead;
args.nhead_k = nhead_k;
args.stride_q = stride_q;
args.stride_k = stride_k;
args.stride_v = stride_v;
args.nhead_stride_q = nhead_stride_q;
args.nhead_stride_k = nhead_stride_k;
args.nhead_stride_v = nhead_stride_v;
args.batch_stride_q = batch_stride_q;
args.batch_stride_k = batch_stride_k;
args.batch_stride_v = batch_stride_v;
args.o_ptr = o_buf.GetDeviceBuffer();
args.seqlen_k = shape_seqlen_k;
args.max_seqlen_q = max_seqlen_q;
args.scale_s = scale_s;
args.stride_o = stride_o;
args.nhead_stride_o = nhead_stride_o;
args.batch_stride_o = batch_stride_o;
args.window_size_left = mask.left;
args.window_size_right = mask.right;
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
};
const auto init_traits = [&](auto& traits) {
traits.hdim_q = hdim_q;
traits.hdim_v = hdim_v;
traits.data_type = data_type;
traits.is_v_rowmajor = is_v_rowmajor;
traits.mask_type = mask.type;
};
fmha_vsa_fwd_traits fmha_traits;
init_traits(fmha_traits);
fmha_vsa_fwd_args args;
init_args(args);
sparge_vsa_fwd(fmha_traits, args, stream_config);
o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
return Y;
}
template ck_tile::HostTensor<ck_tile::half_t>
vsa_sparge_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
const ck_tile::HostTensor<ck_tile::half_t>&,
const ck_tile::HostTensor<ck_tile::half_t>&,
const ck_tile::HostTensor<int32_t>&,
const ck_tile::HostTensor<int32_t>&,
ck_tile::HostTensor<ck_tile::half_t>&,
int,
int,
int,
int,
int,
int,
int,
bool,
bool,
int,
int,
int);
template ck_tile::HostTensor<ck_tile::bf16_t>
vsa_sparge_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
const ck_tile::HostTensor<ck_tile::bf16_t>&,
const ck_tile::HostTensor<ck_tile::bf16_t>&,
const ck_tile::HostTensor<int32_t>&,
const ck_tile::HostTensor<int32_t>&,
ck_tile::HostTensor<ck_tile::bf16_t>&,
int,
int,
int,
int,
int,
int,
int,
bool,
bool,
int,
int,
int);
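
The deleted wrapper above is also where the V layout switch lived: with is_v_rowmajor (the only mode the demo now sets), hdim_v is the contiguous axis of V, while the other branch keeps seqlen_k contiguous. A sketch of just that selection, under the same packed-layout assumption (v_row_stride is an illustrative name):

#include <cstdint>

std::int64_t v_row_stride(bool is_v_rowmajor, bool i_perm,
                          std::int64_t nhead_k, std::int64_t seqlen_k, std::int64_t hdim_v)
{
    if(is_v_rowmajor)                                  // V stored per row as [seqlen_k][hdim_v]
        return i_perm ? hdim_v : nhead_k * hdim_v;
    return i_perm ? seqlen_k : nhead_k * seqlen_k;     // V stored as [hdim_v][seqlen_k]
}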

View File

@@ -1,28 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <optional>
#include <cstdint>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
template <typename DataType_>
ck_tile::HostTensor<DataType_>
vsa_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
const ck_tile::HostTensor<DataType_>& TK,
const ck_tile::HostTensor<DataType_>& TV,
const ck_tile::HostTensor<int32_t>& TKV_block_idx,
const ck_tile::HostTensor<int32_t>& TKV_blocks,
ck_tile::HostTensor<DataType_>& Y,
int batch,
int nhead,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
bool i_perm,
bool o_perm,
int max_seqlen_q,
int max_seqlen_k,
int log_level = 0);
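
A final note for anyone adapting the demo: both kernels are now driven through ck_tile::stream_config constructed positionally. Only two field meanings are confirmed by inline comments in this diff (time_kernel and log_level); the warmup/repeat positions follow from the demo's usage, and the remaining fields are left exactly as the demo passes them:

// Field meanings marked (assumed) are inferred from usage in this diff,
// not from the ck_tile declaration.
ck_tile::stream_config cfg{nullptr,       // stream: default HIP stream (assumed)
                           true,          // time_kernel
                           kname ? 1 : 0, // log_level
                           warmup,        // warm-up iterations (assumed)
                           repeat,        // timed iterations (assumed)
                           false};        // trailing flag, kept as in the demo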