mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
fix extra host side operations.
This commit is contained in:
@@ -249,14 +249,12 @@ set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances")
|
||||
|
||||
add_library(${SPARGE_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL
|
||||
${SPARGE_VSA_GEN_BLOBS}
|
||||
${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp
|
||||
)
|
||||
target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
|
||||
)
|
||||
set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
|
||||
set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp PROPERTIES LANGUAGE HIP)
|
||||
set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE
|
||||
@@ -273,7 +271,6 @@ set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances")
|
||||
|
||||
add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL
|
||||
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
|
||||
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp
|
||||
)
|
||||
target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
@@ -281,7 +278,6 @@ target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
|
||||
)
|
||||
set_source_files_properties(
|
||||
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
|
||||
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp
|
||||
PROPERTIES LANGUAGE HIP
|
||||
)
|
||||
set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "sparge_blockmap.h"
|
||||
#include "sparge_blockmap_trek.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
#include <cmath>
|
||||
|
||||
template <typename DataType_>
|
||||
sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
ck_tile::HostTensor<uint8_t>& block_map_out,
|
||||
int batch,
|
||||
int nhead_q,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
bool i_perm,
|
||||
float simthreshd1,
|
||||
float cdfthreshd,
|
||||
float topk,
|
||||
int blkq,
|
||||
int blkk,
|
||||
int log_level)
|
||||
{
|
||||
static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
|
||||
std::is_same_v<DataType_, ck_tile::bf16_t>,
|
||||
"sparge_blockmap_gpu supports fp16/bf16 only.");
|
||||
|
||||
std::string data_type = "fp16";
|
||||
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
|
||||
{
|
||||
data_type = "bf16";
|
||||
}
|
||||
|
||||
const ck_tile::index_t num_q_blocks = ck_tile::integer_divide_ceil(seqlen_q, blkq);
|
||||
const ck_tile::index_t num_k_blocks = ck_tile::integer_divide_ceil(seqlen_k, blkk);
|
||||
|
||||
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
|
||||
|
||||
const std::size_t bmap_bytes =
|
||||
static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(uint8_t);
|
||||
const std::size_t lut_bytes =
|
||||
static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(int32_t);
|
||||
const std::size_t valid_bytes =
|
||||
static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * sizeof(int32_t);
|
||||
|
||||
ck_tile::DeviceMem bmap_buf(bmap_bytes);
|
||||
ck_tile::DeviceMem lut_buf(lut_bytes);
|
||||
ck_tile::DeviceMem valid_buf(valid_bytes);
|
||||
|
||||
q_buf.ToDevice(TQ.data());
|
||||
k_buf.ToDevice(TK.data());
|
||||
bmap_buf.SetZero();
|
||||
lut_buf.SetZero();
|
||||
valid_buf.SetZero();
|
||||
|
||||
// Compute strides (assumes BHSD if i_perm, BSHD otherwise)
|
||||
const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead_q * hdim_q;
|
||||
const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q;
|
||||
const ck_tile::index_t nhead_stride_q =
|
||||
i_perm ? static_cast<ck_tile::index_t>(seqlen_q) * hdim_q : hdim_q;
|
||||
const ck_tile::index_t nhead_stride_k =
|
||||
i_perm ? static_cast<ck_tile::index_t>(seqlen_k) * hdim_q : hdim_q;
|
||||
const ck_tile::index_t batch_stride_q =
|
||||
static_cast<ck_tile::index_t>(nhead_q) * seqlen_q * hdim_q;
|
||||
const ck_tile::index_t batch_stride_k =
|
||||
static_cast<ck_tile::index_t>(nhead_k) * seqlen_k * hdim_q;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr, false, log_level, 0, 1, false};
|
||||
|
||||
sparge_blockmap_args args;
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.batch = batch;
|
||||
args.seqlen_q = seqlen_q;
|
||||
args.seqlen_k = seqlen_k;
|
||||
args.hdim_q = hdim_q;
|
||||
args.nhead_q = nhead_q;
|
||||
args.nhead_k = nhead_k;
|
||||
args.stride_q = stride_q;
|
||||
args.stride_k = stride_k;
|
||||
args.nhead_stride_q = nhead_stride_q;
|
||||
args.nhead_stride_k = nhead_stride_k;
|
||||
args.batch_stride_q = batch_stride_q;
|
||||
args.batch_stride_k = batch_stride_k;
|
||||
args.simthreshd1 = simthreshd1;
|
||||
args.cdfthreshd = cdfthreshd;
|
||||
args.topk = topk;
|
||||
args.scale = scale;
|
||||
args.block_map_ptr = bmap_buf.GetDeviceBuffer();
|
||||
args.lut_ptr = lut_buf.GetDeviceBuffer();
|
||||
args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
|
||||
|
||||
sparge_blockmap_traits traits;
|
||||
traits.data_type = data_type;
|
||||
traits.hdim_q = hdim_q;
|
||||
|
||||
sparge_blockmap_fwd(traits, args, stream_config);
|
||||
|
||||
// Copy results back to host
|
||||
bmap_buf.FromDevice(block_map_out.data(), bmap_bytes);
|
||||
|
||||
sparge::VSALut vsa_lut{
|
||||
ck_tile::HostTensor<int32_t>({batch, nhead_q, num_q_blocks, num_k_blocks}),
|
||||
ck_tile::HostTensor<int32_t>({batch, nhead_q, num_q_blocks}),
|
||||
};
|
||||
lut_buf.FromDevice(vsa_lut.lut.data(), lut_bytes);
|
||||
valid_buf.FromDevice(vsa_lut.valid_block_num.data(), valid_bytes);
|
||||
|
||||
return vsa_lut;
|
||||
}
|
||||
|
||||
// Explicit template instantiations
|
||||
template sparge::VSALut
|
||||
sparge_blockmap_gpu<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
ck_tile::HostTensor<uint8_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
|
||||
template sparge::VSALut
|
||||
sparge_blockmap_gpu<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
ck_tile::HostTensor<uint8_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
float,
|
||||
float,
|
||||
float,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
@@ -1,26 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "sparge_tool.hpp"
|
||||
|
||||
template <typename DataType_>
|
||||
sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
ck_tile::HostTensor<uint8_t>& block_map_out,
|
||||
int batch,
|
||||
int nhead_q,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
bool i_perm,
|
||||
float simthreshd1,
|
||||
float cdfthreshd,
|
||||
float topk,
|
||||
int blkq,
|
||||
int blkk,
|
||||
int log_level = 0);
|
||||
@@ -1,23 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention
|
||||
// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention (all-in-device)
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <chrono>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/reference/reference_blocked_attention.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
|
||||
#include "vsa_sparge_attention.h"
|
||||
#include "sparge_blockmap.h"
|
||||
#include "sparge_blockmap_trek.hpp"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "sparge_tool.hpp"
|
||||
|
||||
// ============================================================================
|
||||
@@ -192,7 +186,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
<< ", topk=" << topk << ")" << std::endl;
|
||||
std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl;
|
||||
|
||||
// Create host tensors
|
||||
// Create host tensors and fill with random data
|
||||
ck_tile::HostTensor<T> q_host = make_qkv_tensor<T>(batch, nhead, seqlen_q, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> k_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_q, i_perm);
|
||||
ck_tile::HostTensor<T> v_host = make_qkv_tensor<T>(batch, nhead_k, seqlen_k, hdim_v, i_perm);
|
||||
@@ -206,119 +200,157 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
|
||||
|
||||
// ==================================================================
|
||||
// GPU: Build block map + VSA LUT in one kernel (always run)
|
||||
// Allocate device memory once, HtoD once
|
||||
// ==================================================================
|
||||
std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl;
|
||||
ck_tile::HostTensor<uint8_t> block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
auto vsa_lut_gpu = sparge_blockmap_gpu<T>(q_host,
|
||||
k_host,
|
||||
block_map_gpu,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
i_perm,
|
||||
simthreshd1,
|
||||
cdfthreshd,
|
||||
topk,
|
||||
static_cast<int>(BLKQ),
|
||||
static_cast<int>(BLKK),
|
||||
0);
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(output_host.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
|
||||
const std::size_t bmap_bytes =
|
||||
static_cast<std::size_t>(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(uint8_t);
|
||||
const std::size_t lut_bytes =
|
||||
static_cast<std::size_t>(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(int32_t);
|
||||
const std::size_t valid_bytes =
|
||||
static_cast<std::size_t>(batch) * nhead * num_q_blocks * sizeof(int32_t);
|
||||
|
||||
ck_tile::DeviceMem bmap_buf(bmap_bytes);
|
||||
ck_tile::DeviceMem lut_buf(lut_bytes);
|
||||
ck_tile::DeviceMem valid_buf(valid_bytes);
|
||||
bmap_buf.SetZero();
|
||||
lut_buf.SetZero();
|
||||
valid_buf.SetZero();
|
||||
|
||||
// ==================================================================
|
||||
// VSA sparse attention kernel (always run)
|
||||
// Common stride calculations
|
||||
// ==================================================================
|
||||
assert(nhead % nhead_k == 0);
|
||||
const float scale_s = 1.0f / std::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead * hdim_q;
|
||||
const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q;
|
||||
const ck_tile::index_t stride_v = i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
const ck_tile::index_t stride_o = o_perm ? hdim_v : nhead * hdim_v;
|
||||
const ck_tile::index_t nhead_stride_q = i_perm ? seqlen_q * hdim_q : hdim_q;
|
||||
const ck_tile::index_t nhead_stride_k = i_perm ? seqlen_k * hdim_q : hdim_q;
|
||||
const ck_tile::index_t nhead_stride_v = i_perm ? seqlen_k * hdim_v : hdim_v;
|
||||
const ck_tile::index_t nhead_stride_o = o_perm ? seqlen_q * hdim_v : hdim_v;
|
||||
const ck_tile::index_t batch_stride_q = nhead * seqlen_q * hdim_q;
|
||||
const ck_tile::index_t batch_stride_k = nhead_k * seqlen_k * hdim_q;
|
||||
const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * seqlen_k;
|
||||
const ck_tile::index_t batch_stride_o = nhead * seqlen_q * hdim_v;
|
||||
|
||||
std::string data_type = "fp16";
|
||||
if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
data_type = "bf16";
|
||||
|
||||
std::string msk_str = "0";
|
||||
mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
|
||||
|
||||
// ==================================================================
|
||||
// GPU: Build block map + VSA LUT (always run, device-only)
|
||||
// ==================================================================
|
||||
std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl;
|
||||
{
|
||||
sparge_blockmap_args args;
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.batch = batch;
|
||||
args.seqlen_q = seqlen_q;
|
||||
args.seqlen_k = seqlen_k;
|
||||
args.hdim_q = hdim_q;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead_k;
|
||||
args.stride_q = stride_q;
|
||||
args.stride_k = stride_k;
|
||||
args.nhead_stride_q = nhead_stride_q;
|
||||
args.nhead_stride_k = nhead_stride_k;
|
||||
args.batch_stride_q = batch_stride_q;
|
||||
args.batch_stride_k = batch_stride_k;
|
||||
args.simthreshd1 = simthreshd1;
|
||||
args.cdfthreshd = cdfthreshd;
|
||||
args.topk = topk;
|
||||
args.scale = scale_s;
|
||||
args.block_map_ptr = bmap_buf.GetDeviceBuffer();
|
||||
args.lut_ptr = lut_buf.GetDeviceBuffer();
|
||||
args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
|
||||
|
||||
sparge_blockmap_traits traits;
|
||||
traits.data_type = data_type;
|
||||
traits.hdim_q = hdim_q;
|
||||
|
||||
sparge_blockmap_fwd(traits, args, ck_tile::stream_config{});
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// VSA sparse attention kernel (always run, LUT stays on device)
|
||||
// ==================================================================
|
||||
std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
|
||||
|
||||
try
|
||||
{
|
||||
if(kname)
|
||||
{
|
||||
vsa_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut_gpu.lut,
|
||||
vsa_lut_gpu.valid_block_num,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
1);
|
||||
}
|
||||
fmha_vsa_fwd_args fmha_args;
|
||||
fmha_args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
fmha_args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
fmha_args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
fmha_args.lut_ptr = lut_buf.GetDeviceBuffer();
|
||||
fmha_args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();
|
||||
fmha_args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
fmha_args.batch = batch;
|
||||
fmha_args.seqlen_q = seqlen_q;
|
||||
fmha_args.seqlen_k = seqlen_k;
|
||||
fmha_args.max_seqlen_q = seqlen_q;
|
||||
fmha_args.hdim_q = hdim_q;
|
||||
fmha_args.hdim_v = hdim_v;
|
||||
fmha_args.nhead_q = nhead;
|
||||
fmha_args.nhead_k = nhead_k;
|
||||
fmha_args.scale_s = scale_s;
|
||||
fmha_args.stride_q = stride_q;
|
||||
fmha_args.stride_k = stride_k;
|
||||
fmha_args.stride_v = stride_v;
|
||||
fmha_args.stride_o = stride_o;
|
||||
fmha_args.nhead_stride_q = nhead_stride_q;
|
||||
fmha_args.nhead_stride_k = nhead_stride_k;
|
||||
fmha_args.nhead_stride_v = nhead_stride_v;
|
||||
fmha_args.nhead_stride_o = nhead_stride_o;
|
||||
fmha_args.batch_stride_q = batch_stride_q;
|
||||
fmha_args.batch_stride_k = batch_stride_k;
|
||||
fmha_args.batch_stride_v = batch_stride_v;
|
||||
fmha_args.batch_stride_o = batch_stride_o;
|
||||
fmha_args.window_size_left = mask.left;
|
||||
fmha_args.window_size_right = mask.right;
|
||||
fmha_args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||
|
||||
for(int i = 0; i < warmup; ++i)
|
||||
{
|
||||
vsa_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut_gpu.lut,
|
||||
vsa_lut_gpu.valid_block_num,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
fmha_vsa_fwd_traits fmha_traits;
|
||||
fmha_traits.hdim_q = hdim_q;
|
||||
fmha_traits.hdim_v = hdim_v;
|
||||
fmha_traits.data_type = data_type;
|
||||
fmha_traits.is_v_rowmajor = true;
|
||||
fmha_traits.mask_type = mask.type;
|
||||
|
||||
[[maybe_unused]] auto sync_status1 = hipDeviceSynchronize();
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/* log_level = */ kname ? 1 : 0,
|
||||
warmup,
|
||||
repeat,
|
||||
false};
|
||||
|
||||
for(int i = 0; i < repeat; ++i)
|
||||
{
|
||||
vsa_sparge_attention<T>(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
vsa_lut_gpu.lut,
|
||||
vsa_lut_gpu.valid_block_num,
|
||||
output_host,
|
||||
batch,
|
||||
nhead,
|
||||
nhead_k,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
i_perm,
|
||||
o_perm,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
0);
|
||||
}
|
||||
float avg_time_ms = sparge_vsa_fwd(fmha_traits, fmha_args, stream_config);
|
||||
|
||||
[[maybe_unused]] auto sync_status2 = hipDeviceSynchronize();
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
double avg_time_ms =
|
||||
std::chrono::duration<double, std::milli>(end - start).count() / repeat;
|
||||
std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
|
||||
<< std::endl;
|
||||
|
||||
std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<"
|
||||
<< std::endl;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error during kernel execution: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
// DtoH: attention output (always needed)
|
||||
o_buf.FromDevice(output_host.data(), output_host.get_element_space_size_in_bytes());
|
||||
|
||||
// DtoH: block_map (needed for sparsity stats and validation)
|
||||
ck_tile::HostTensor<uint8_t> block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks});
|
||||
bmap_buf.FromDevice(block_map_gpu.data(), bmap_bytes);
|
||||
|
||||
// ==================================================================
|
||||
// Sparsity statistics (always run, pure CPU read of HostTensor)
|
||||
// Sparsity statistics (pure CPU, reads block_map HostTensor)
|
||||
// ==================================================================
|
||||
std::size_t total_blocks = 0;
|
||||
std::size_t active_blocks = 0;
|
||||
@@ -366,6 +398,14 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl;
|
||||
auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
|
||||
|
||||
// DtoH: LUT + valid_block_num (only for validation)
|
||||
sparge::VSALut vsa_lut_gpu{
|
||||
ck_tile::HostTensor<int32_t>({batch, nhead, num_q_blocks, num_k_blocks}),
|
||||
ck_tile::HostTensor<int32_t>({batch, nhead, num_q_blocks}),
|
||||
};
|
||||
lut_buf.FromDevice(vsa_lut_gpu.lut.data(), lut_bytes);
|
||||
valid_buf.FromDevice(vsa_lut_gpu.valid_block_num.data(), valid_bytes);
|
||||
|
||||
// Validate block map
|
||||
std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl;
|
||||
{
|
||||
@@ -378,20 +418,16 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
if(block_map_gpu(b, h, qb, kb) !=
|
||||
block_relation_onehot(b, h, qb, kb))
|
||||
if(block_map_gpu(b, h, qb, kb) != block_relation_onehot(b, h, qb, kb))
|
||||
{
|
||||
bmap_mismatches++;
|
||||
if(bmap_mismatches <= 10)
|
||||
{
|
||||
std::cout
|
||||
<< " block_map mismatch at [" << b << "," << h << ","
|
||||
<< qb << "," << kb
|
||||
<< "]: GPU="
|
||||
<< static_cast<int>(block_map_gpu(b, h, qb, kb))
|
||||
<< " CPU="
|
||||
<< static_cast<int>(
|
||||
block_relation_onehot(b, h, qb, kb))
|
||||
<< " block_map mismatch at [" << b << "," << h << "," << qb
|
||||
<< "," << kb << "]: GPU="
|
||||
<< static_cast<int>(block_map_gpu(b, h, qb, kb)) << " CPU="
|
||||
<< static_cast<int>(block_relation_onehot(b, h, qb, kb))
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
@@ -429,28 +465,24 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
|
||||
valid_mismatches++;
|
||||
if(valid_mismatches <= 5)
|
||||
{
|
||||
std::cout
|
||||
<< " valid_block_num mismatch at [" << b << "," << h
|
||||
<< "," << qb
|
||||
<< "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb)
|
||||
<< " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb)
|
||||
<< std::endl;
|
||||
std::cout << " valid_block_num mismatch at [" << b << "," << h
|
||||
<< "," << qb
|
||||
<< "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb)
|
||||
<< " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb)
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
|
||||
{
|
||||
if(vsa_lut_gpu.lut(b, h, qb, kb) !=
|
||||
vsa_lut_cpu.lut(b, h, qb, kb))
|
||||
if(vsa_lut_gpu.lut(b, h, qb, kb) != vsa_lut_cpu.lut(b, h, qb, kb))
|
||||
{
|
||||
lut_mismatches++;
|
||||
if(lut_mismatches <= 10)
|
||||
{
|
||||
std::cout
|
||||
<< " LUT mismatch at [" << b << "," << h << "," << qb
|
||||
<< "," << kb
|
||||
<< "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb)
|
||||
<< " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb)
|
||||
<< std::endl;
|
||||
<< "," << kb << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb)
|
||||
<< " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,195 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#include "vsa_sparge_attention.h"
|
||||
#include "fmha_fwd_trek.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
vsa_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
const ck_tile::HostTensor<DataType_>& TV,
|
||||
const ck_tile::HostTensor<int32_t>& TKV_block_idx,
|
||||
const ck_tile::HostTensor<int32_t>& TKV_blocks,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k,
|
||||
int log_level)
|
||||
{
|
||||
static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
|
||||
std::is_same_v<DataType_, ck_tile::bf16_t>,
|
||||
"VSA sparse attention supports fp16/bf16 only.");
|
||||
std::string data_type = "fp16";
|
||||
if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
|
||||
{
|
||||
data_type = "bf16";
|
||||
}
|
||||
|
||||
if(max_seqlen_q == 0)
|
||||
max_seqlen_q = seqlen_q;
|
||||
if(max_seqlen_k == 0)
|
||||
max_seqlen_k = seqlen_k;
|
||||
bool is_v_rowmajor = true;
|
||||
float scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
|
||||
std::string msk_str = "0";
|
||||
mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k);
|
||||
|
||||
const ck_tile::index_t shape_seqlen_q = seqlen_q;
|
||||
const ck_tile::index_t shape_seqlen_k = seqlen_k;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
false, // time_kernel
|
||||
log_level,
|
||||
0,
|
||||
1,
|
||||
false};
|
||||
|
||||
ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(TQ.data());
|
||||
k_buf.ToDevice(TK.data());
|
||||
v_buf.ToDevice(TV.data());
|
||||
lut_buf.ToDevice(TKV_block_idx.data());
|
||||
valid_block_num_buf.ToDevice(TKV_blocks.data());
|
||||
|
||||
const auto init_args = [&](auto& args) {
|
||||
assert(nhead % nhead_k == 0);
|
||||
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
|
||||
const ck_tile::index_t stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? hdim_v : nhead_k * hdim_v;
|
||||
else
|
||||
return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k);
|
||||
}();
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q;
|
||||
const ck_tile::index_t nhead_stride_v = [&]() {
|
||||
if(is_v_rowmajor)
|
||||
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
|
||||
else
|
||||
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
|
||||
}();
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q;
|
||||
const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k;
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
args.lut_ptr = lut_buf.GetDeviceBuffer();
|
||||
args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer();
|
||||
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q;
|
||||
args.hdim_q = hdim_q;
|
||||
args.hdim_v = hdim_v;
|
||||
args.nhead_q = nhead;
|
||||
args.nhead_k = nhead_k;
|
||||
|
||||
args.stride_q = stride_q;
|
||||
args.stride_k = stride_k;
|
||||
args.stride_v = stride_v;
|
||||
args.nhead_stride_q = nhead_stride_q;
|
||||
args.nhead_stride_k = nhead_stride_k;
|
||||
args.nhead_stride_v = nhead_stride_v;
|
||||
args.batch_stride_q = batch_stride_q;
|
||||
args.batch_stride_k = batch_stride_k;
|
||||
args.batch_stride_v = batch_stride_v;
|
||||
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
|
||||
args.seqlen_k = shape_seqlen_k;
|
||||
args.max_seqlen_q = max_seqlen_q;
|
||||
|
||||
args.scale_s = scale_s;
|
||||
|
||||
args.stride_o = stride_o;
|
||||
args.nhead_stride_o = nhead_stride_o;
|
||||
args.batch_stride_o = batch_stride_o;
|
||||
|
||||
args.window_size_left = mask.left;
|
||||
args.window_size_right = mask.right;
|
||||
args.mask_type = static_cast<ck_tile::index_t>(mask.type);
|
||||
};
|
||||
|
||||
const auto init_traits = [&](auto& traits) {
|
||||
traits.hdim_q = hdim_q;
|
||||
traits.hdim_v = hdim_v;
|
||||
traits.data_type = data_type;
|
||||
traits.is_v_rowmajor = is_v_rowmajor;
|
||||
traits.mask_type = mask.type;
|
||||
};
|
||||
|
||||
fmha_vsa_fwd_traits fmha_traits;
|
||||
init_traits(fmha_traits);
|
||||
|
||||
fmha_vsa_fwd_args args;
|
||||
init_args(args);
|
||||
|
||||
sparge_vsa_fwd(fmha_traits, args, stream_config);
|
||||
|
||||
o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes());
|
||||
|
||||
return Y;
|
||||
}
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::half_t>
|
||||
vsa_sparge_attention<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
const ck_tile::HostTensor<int32_t>&,
|
||||
const ck_tile::HostTensor<int32_t>&,
|
||||
ck_tile::HostTensor<ck_tile::half_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
|
||||
template ck_tile::HostTensor<ck_tile::bf16_t>
|
||||
vsa_sparge_attention<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
const ck_tile::HostTensor<int32_t>&,
|
||||
const ck_tile::HostTensor<int32_t>&,
|
||||
ck_tile::HostTensor<ck_tile::bf16_t>&,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
int,
|
||||
bool,
|
||||
bool,
|
||||
int,
|
||||
int,
|
||||
int);
|
||||
@@ -1,28 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <optional>
|
||||
#include <cstdint>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
template <typename DataType_>
|
||||
ck_tile::HostTensor<DataType_>
|
||||
vsa_sparge_attention(const ck_tile::HostTensor<DataType_>& TQ,
|
||||
const ck_tile::HostTensor<DataType_>& TK,
|
||||
const ck_tile::HostTensor<DataType_>& TV,
|
||||
const ck_tile::HostTensor<int32_t>& TKV_block_idx,
|
||||
const ck_tile::HostTensor<int32_t>& TKV_blocks,
|
||||
ck_tile::HostTensor<DataType_>& Y,
|
||||
int batch,
|
||||
int nhead,
|
||||
int nhead_k,
|
||||
int seqlen_q,
|
||||
int seqlen_k,
|
||||
int hdim_q,
|
||||
int hdim_v,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
int max_seqlen_q,
|
||||
int max_seqlen_k,
|
||||
int log_level = 0);
|
||||
Reference in New Issue
Block a user