Add sparge gpu pipeline in tile_example_sparge_vsa_sparse_attn

This commit is contained in:
Gino Lu
2026-04-13 03:34:08 -04:00
parent 643ad35de2
commit d1d457b82a
8 changed files with 1295 additions and 50 deletions

View File

@@ -266,11 +266,41 @@ target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE
-Wno-float-equal
)
# Sparge + VSA Example executable
# ============================================================================
# Sparge BlockMap GPU Kernel (hand-written instantiation, no codegen)
# ============================================================================
# Object library holding the block-map kernel instantiation and its host-side
# wrapper. EXCLUDE_FROM_ALL keeps it out of the default build; it is only
# compiled when a target that consumes it is built.
set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances")
add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp
)
# The sources include both example-local headers (this directory) and the
# sparse_attn headers from the ck_tile include tree.
target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
${CMAKE_CURRENT_LIST_DIR}
${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn
)
# The files carry a .cpp extension, so they must be marked as HIP explicitly to
# be compiled as HIP sources.
set_source_files_properties(
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp
${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp
PROPERTIES LANGUAGE HIP
)
# Build device code for the GPU targets listed in INST_TARGETS (set outside
# this fragment).
set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
# CK_TILE_FMHA_FWD_FAST_EXP2 is used as a template argument in
# sparge_blockmap_inst.cpp, so it has to be defined for this target.
target_compile_options(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE
-DCK_TILE_USE_BUFFER_ADDRESSING_BUILTIN
-DCK_TILE_FMHA_FWD_FAST_EXP2
-Wno-undefined-func-template
-Wno-float-equal
)
# Sparge + VSA Example executable (now links blockmap kernel too)
set(EXAMPLE_SPARGE_VSA_SPARSE_ATTN "tile_example_sparge_vsa_sparse_attn")
message(DEBUG "adding example ${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}")
add_executable(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} EXCLUDE_FROM_ALL test_sparge_vsa_sparse_attn.cpp)
target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} ${SPARGE_VSA_INSTANCES})
target_link_libraries(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN}
${SPARGE_VSA_INSTANCES}
${SPARGE_BLOCKMAP_INSTANCES}
)
target_include_directories(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_compile_options(${EXAMPLE_SPARGE_VSA_SPARSE_ATTN} PRIVATE
-Wno-undefined-func-template

View File

@@ -0,0 +1,156 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "sparge_blockmap.h"
#include "sparge_blockmap_trek.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/device_memory.hpp"
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <type_traits>
/// @brief Build the Sparge block map and its VSA lookup table on the GPU.
///
/// Copies Q and K to the device, launches the block-map kernel through
/// sparge_blockmap_fwd(), and copies the three outputs back to the host: the one-hot block
/// map (into @p block_map_out), the LUT and the per-Q-block valid-block count (both returned).
///
/// @tparam DataType_    Element type of Q/K; only ck_tile::half_t and ck_tile::bf16_t compile.
/// @param TQ            Query tensor on the host.
/// @param TK            Key tensor on the host.
/// @param block_map_out Receives batch * nhead_q * ceil(seqlen_q / blkq) * ceil(seqlen_k / blkk)
///                      bytes; must be at least that large.
/// @param batch         Batch size.
/// @param nhead_q       Number of query heads.
/// @param nhead_k       Number of key heads; must be positive and divide @p nhead_q.
/// @param seqlen_q      Query sequence length.
/// @param seqlen_k      Key sequence length.
/// @param hdim_q        Head dimension of Q/K; also determines the softmax scale 1/sqrt(hdim_q).
/// @param i_perm        true: strides are computed for [batch, nhead, seqlen, hdim];
///                      false: for [batch, seqlen, nhead, hdim].
/// @param simthreshd1   Forwarded unchanged to the kernel arguments.
/// @param cdfthreshd    Forwarded unchanged to the kernel arguments.
/// @param topk          Forwarded unchanged to the kernel arguments.
/// @param blkq          Q block size used to size the outputs; must be positive.
/// @param blkk          K block size used to size the outputs; must be positive.
/// @param log_level     Forwarded to ck_tile::stream_config.
/// @return LUT of shape [batch, nhead_q, num_q_blocks, num_k_blocks] and valid-block counts of
///         shape [batch, nhead_q, num_q_blocks].
/// @throws std::invalid_argument if blkq/blkk/nhead_k is not positive, nhead_q is not a multiple
///         of nhead_k, or block_map_out is too small for the block map.
/// @throws std::runtime_error if sparge_blockmap_fwd() has no kernel for this data type / hdim_q.
template <typename DataType_>
sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor<DataType_>& TQ,
                                   const ck_tile::HostTensor<DataType_>& TK,
                                   ck_tile::HostTensor<uint8_t>& block_map_out,
                                   int batch,
                                   int nhead_q,
                                   int nhead_k,
                                   int seqlen_q,
                                   int seqlen_k,
                                   int hdim_q,
                                   bool i_perm,
                                   float simthreshd1,
                                   float cdfthreshd,
                                   float topk,
                                   int blkq,
                                   int blkk,
                                   int log_level)
{
    static_assert(std::is_same_v<DataType_, ck_tile::half_t> ||
                      std::is_same_v<DataType_, ck_tile::bf16_t>,
                  "sparge_blockmap_gpu supports fp16/bf16 only.");

    // These values are used as divisors below (and nhead_q / nhead_k further down the call
    // chain), so reject them before they can trigger a division by zero.
    if(blkq <= 0 || blkk <= 0 || nhead_k <= 0)
    {
        throw std::invalid_argument("sparge_blockmap_gpu: blkq, blkk and nhead_k must be positive");
    }
    if(nhead_q % nhead_k != 0)
    {
        throw std::invalid_argument("sparge_blockmap_gpu: nhead_q must be a multiple of nhead_k");
    }

    std::string data_type = "fp16";
    if constexpr(std::is_same_v<DataType_, ck_tile::bf16_t>)
    {
        data_type = "bf16";
    }

    // NOTE(review): blkq/blkk only size the buffers here; they are not part of
    // sparge_blockmap_args, so the kernel's own tile sizes come from its instantiation.
    // Presumably they must match the instantiated kM0/kN0 - confirm.
    const ck_tile::index_t num_q_blocks = ck_tile::integer_divide_ceil(seqlen_q, blkq);
    const ck_tile::index_t num_k_blocks = ck_tile::integer_divide_ceil(seqlen_k, blkk);
    const float scale                   = 1.0f / std::sqrt(static_cast<float>(hdim_q));

    const std::size_t bmap_bytes =
        static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(uint8_t);
    const std::size_t lut_bytes =
        static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(int32_t);
    const std::size_t valid_bytes =
        static_cast<std::size_t>(batch) * nhead_q * num_q_blocks * sizeof(int32_t);

    // The copy-back below writes bmap_bytes into block_map_out unconditionally.
    if(block_map_out.get_element_space_size_in_bytes() < bmap_bytes)
    {
        throw std::invalid_argument(
            "sparge_blockmap_gpu: block_map_out is smaller than the block map");
    }

    // Allocate device memory
    ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes());
    ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes());
    ck_tile::DeviceMem bmap_buf(bmap_bytes);
    ck_tile::DeviceMem lut_buf(lut_bytes);
    ck_tile::DeviceMem valid_buf(valid_bytes);

    q_buf.ToDevice(TQ.data());
    k_buf.ToDevice(TK.data());
    bmap_buf.SetZero();
    lut_buf.SetZero();
    valid_buf.SetZero();

    // Compute strides (assumes BHSD if i_perm, BSHD otherwise)
    const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead_q * hdim_q;
    const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q;
    const ck_tile::index_t nhead_stride_q =
        i_perm ? static_cast<ck_tile::index_t>(seqlen_q) * hdim_q : hdim_q;
    const ck_tile::index_t nhead_stride_k =
        i_perm ? static_cast<ck_tile::index_t>(seqlen_k) * hdim_q : hdim_q;
    const ck_tile::index_t batch_stride_q =
        static_cast<ck_tile::index_t>(nhead_q) * seqlen_q * hdim_q;
    const ck_tile::index_t batch_stride_k =
        static_cast<ck_tile::index_t>(nhead_k) * seqlen_k * hdim_q;

    ck_tile::stream_config stream_config{nullptr, false, log_level, 0, 1, false};

    sparge_blockmap_args args;
    args.q_ptr               = q_buf.GetDeviceBuffer();
    args.k_ptr               = k_buf.GetDeviceBuffer();
    args.batch               = batch;
    args.seqlen_q            = seqlen_q;
    args.seqlen_k            = seqlen_k;
    args.hdim_q              = hdim_q;
    args.nhead_q             = nhead_q;
    args.nhead_k             = nhead_k;
    args.stride_q            = stride_q;
    args.stride_k            = stride_k;
    args.nhead_stride_q      = nhead_stride_q;
    args.nhead_stride_k      = nhead_stride_k;
    args.batch_stride_q      = batch_stride_q;
    args.batch_stride_k      = batch_stride_k;
    args.simthreshd1         = simthreshd1;
    args.cdfthreshd          = cdfthreshd;
    args.topk                = topk;
    args.scale               = scale;
    args.block_map_ptr       = bmap_buf.GetDeviceBuffer();
    args.lut_ptr             = lut_buf.GetDeviceBuffer();
    args.valid_block_num_ptr = valid_buf.GetDeviceBuffer();

    sparge_blockmap_traits traits;
    traits.data_type = data_type;
    traits.hdim_q    = hdim_q;

    // The dispatcher returns a negative value when it has no kernel for the requested
    // configuration. Without this check the zero-filled buffers would be returned as if
    // they were a valid (fully sparse) result.
    const float elapsed = sparge_blockmap_fwd(traits, args, stream_config);
    if(elapsed < 0.f)
    {
        throw std::runtime_error("sparge_blockmap_gpu: no kernel instance for data_type=" +
                                 data_type + ", hdim_q=" + std::to_string(hdim_q));
    }

    // Copy results back to host
    bmap_buf.FromDevice(block_map_out.data(), bmap_bytes);

    sparge::VSALut vsa_lut{
        ck_tile::HostTensor<int32_t>({batch, nhead_q, num_q_blocks, num_k_blocks}),
        ck_tile::HostTensor<int32_t>({batch, nhead_q, num_q_blocks}),
    };
    lut_buf.FromDevice(vsa_lut.lut.data(), lut_bytes);
    valid_buf.FromDevice(vsa_lut.valid_block_num.data(), valid_bytes);
    return vsa_lut;
}
// Explicit template instantiations
// The template is defined in this translation unit while the header only declares it, so each
// element type callers use is instantiated here. This one is for ck_tile::half_t (fp16).
template sparge::VSALut
sparge_blockmap_gpu<ck_tile::half_t>(const ck_tile::HostTensor<ck_tile::half_t>&,
const ck_tile::HostTensor<ck_tile::half_t>&,
ck_tile::HostTensor<uint8_t>&,
int, // batch
int, // nhead_q
int, // nhead_k
int, // seqlen_q
int, // seqlen_k
int, // hdim_q
bool, // i_perm
float, // simthreshd1
float, // cdfthreshd
float, // topk
int, // blkq
int, // blkk
int); // log_level
// Instantiation for ck_tile::bf16_t (bf16).
// NOTE(review): the dispatcher in sparge_blockmap_inst.cpp only has a branch for fp16 with
// hdim_q == 128, so this instantiation links but appears to have no kernel behind it - confirm.
template sparge::VSALut
sparge_blockmap_gpu<ck_tile::bf16_t>(const ck_tile::HostTensor<ck_tile::bf16_t>&,
const ck_tile::HostTensor<ck_tile::bf16_t>&,
ck_tile::HostTensor<uint8_t>&,
int, // batch
int, // nhead_q
int, // nhead_k
int, // seqlen_q
int, // seqlen_k
int, // hdim_q
bool, // i_perm
float, // simthreshd1
float, // cdfthreshd
float, // topk
int, // blkq
int, // blkk
int); // log_level

View File

@@ -0,0 +1,26 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdint>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "sparge_tool.hpp"
/// @brief Build the Sparge block map and VSA LUT on the GPU from host-side Q and K tensors.
///
/// Declaration only: the definition and the explicit instantiations for ck_tile::half_t and
/// ck_tile::bf16_t live in sparge_blockmap.cpp.
///
/// @tparam DataType_    Element type of Q/K.
/// @param TQ            Query tensor on the host.
/// @param TK            Key tensor on the host.
/// @param block_map_out Output: one uint8_t per (batch, head, Q block, K block).
/// @param batch         Batch size.
/// @param nhead_q       Number of query heads.
/// @param nhead_k       Number of key heads.
/// @param seqlen_q      Query sequence length.
/// @param seqlen_k      Key sequence length.
/// @param hdim_q        Head dimension of Q/K.
/// @param i_perm        Selects the stride computation: [batch, nhead, seqlen, hdim] when true,
///                      [batch, seqlen, nhead, hdim] when false.
/// @param simthreshd1   Threshold forwarded to the kernel.
/// @param cdfthreshd    Threshold forwarded to the kernel.
/// @param topk          Value forwarded to the kernel.
/// @param blkq          Q block size used to derive the number of Q blocks.
/// @param blkk          K block size used to derive the number of K blocks.
/// @param log_level     Forwarded to ck_tile::stream_config; defaults to 0.
/// @return The LUT and per-Q-block valid-block counts copied back from the device.
template <typename DataType_>
sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor<DataType_>& TQ,
const ck_tile::HostTensor<DataType_>& TK,
ck_tile::HostTensor<uint8_t>& block_map_out,
int batch,
int nhead_q,
int nhead_k,
int seqlen_q,
int seqlen_k,
int hdim_q,
bool i_perm,
float simthreshd1,
float cdfthreshd,
float topk,
int blkq,
int blkk,
int log_level = 0);

View File

@@ -0,0 +1,88 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Hand-written template instantiation for SpargeBlockMapKernel (fp16, D=128).
#include "sparge_blockmap_trek.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <iostream>
// ============================================================================
// Type configuration for block map kernel
// ============================================================================
// The aliases below assemble the single kernel instance this file provides. All data types are
// spelled out directly in bmap_fp16_problem; no shared type-config struct is referenced here.
// fp16: D=128, kM0=64, kN0=128
using bmap_fp16_block_tile = ck_tile::sequence<64, 128, 128, 128, 128, 128>;
// Order of the entries above: kM0 kN0 kK0 kN1 kK1 kQKHeaddim(D)
using bmap_fp16_shape =
ck_tile::TileFmhaShape<bmap_fp16_block_tile,
ck_tile::sequence<4, 1, 1>, // Gemm0BlockWarps
ck_tile::sequence<16, 16, 16>, // Gemm0WarpTile (unused by blockmap, but
// needed by shape)
ck_tile::sequence<4, 1, 1>, // Gemm1BlockWarps
ck_tile::sequence<16, 16, 16>, // Gemm1WarpTile
true>; // VLayout row-major
// Trait flags: padding is enabled on every dimension; bias, LSE, dropout, random values and
// quantization scaling are all switched off.
using bmap_fp16_trait = ck_tile::TileFmhaTraits<true, // kPadSeqLenQ
true, // kPadSeqLenK
true, // kPadHeadDimQ
true, // kPadHeadDimV
false, // kHasLogitsSoftCap
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false, // kStoreLSE
false, // kHasDropout
false, // kHasRandVal
ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE,
-1, // kBlockPerCu
false>; // kIsVRowMajorSkip
// CK_TILE_FMHA_FWD_FAST_EXP2 is a preprocessor macro and must be defined when compiling this
// file, since it is used as a template argument here.
using bmap_fp16_variant = ck_tile::ComposedAttention<0, CK_TILE_FMHA_FWD_FAST_EXP2>;
using bmap_fp16_mask = ck_tile::GenericAttentionMask<false>;
// Full pipeline problem: half_t for stored tensors, float for accumulation/compute types.
using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::half_t, // QDataType
ck_tile::half_t, // KDataType
ck_tile::half_t, // VDataType
float, // SaccDataType
float, // SMPLComputeDataType
ck_tile::half_t, // BiasDataType
uint8_t, // RandValOutputDataType
float, // LSEDataType
ck_tile::half_t, // PDataType
float, // OaccDataType
ck_tile::half_t, // ODataType
bmap_fp16_shape,
false, // kIsGroupMode
bmap_fp16_variant,
bmap_fp16_mask,
false, // kUseTrLoad
bmap_fp16_trait>;
// Pipeline and kernel types that sparge_blockmap_fwd() launches for fp16 / hdim 128.
using bmap_fp16_pipeline = ck_tile::SpargeBlockMapPipeline<bmap_fp16_problem>;
using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel<bmap_fp16_pipeline>;
// ============================================================================
// Dispatch
// ============================================================================
/// @brief Select and launch the block-map kernel that matches @p traits.
///
/// Only one configuration is instantiated in this file: data_type "fp16" with hdim_q 128.
///
/// @param traits  Runtime description of the requested kernel (data type name, head dim).
/// @param args    Host-side kernel arguments, converted to kernel args before launch.
/// @param s       Stream configuration; log_level_ > 0 enables the diagnostic prints.
/// @return The value returned by ck_tile::launch_kernel for a supported configuration,
///         or -1.f when no instance matches.
float sparge_blockmap_fwd(sparge_blockmap_traits traits,
                          sparge_blockmap_args args,
                          const ck_tile::stream_config& s)
{
    const bool verbose      = s.log_level_ > 0;
    const bool is_fp16_d128 = traits.data_type == "fp16" && traits.hdim_q == 128;

    // Guard clause: anything other than the single instantiated configuration is rejected.
    if(!is_fp16_d128)
    {
        if(verbose)
        {
            std::cerr << "sparge_blockmap_fwd: unsupported config (data_type=" << traits.data_type
                      << ", hdim_q=" << traits.hdim_q << ")" << std::endl;
        }
        return -1.f;
    }

    using Kernel = bmap_fp16_kernel;

    if(verbose)
    {
        std::cout << ", sparge_blockmap_fp16_d128" << std::flush;
    }

    auto [kernel_args, grid_dim]              = sparge_blockmap_create_kargs_and_grids<Kernel>(args);
    const dim3 block_dim                      = Kernel::BlockSize();
    constexpr ck_tile::index_t kBlocksPerCu_  = Kernel::kBlockPerCu;

    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<kBlocksPerCu_>(Kernel{}, grid_dim, block_dim, 0, kernel_args));
}

View File

@@ -0,0 +1,93 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp"
#include "ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp"
#include "fmha_fwd_trek.hpp"
#include <string>
#include <type_traits>
// ============================================================================
// Args and traits for sparge block map GPU kernel
// ============================================================================
/// Host-side argument bundle for the Sparge block-map kernel.
///
/// Every member has a default initializer so that a default-constructed instance never carries
/// indeterminate values; callers are still expected to fill in every field before dispatch.
struct sparge_blockmap_args
{
    // Device pointers to the Q and K inputs.
    const void* q_ptr = nullptr;
    const void* k_ptr = nullptr;

    // Problem sizes.
    ck_tile::index_t batch    = 0;
    ck_tile::index_t seqlen_q = 0;
    ck_tile::index_t seqlen_k = 0;
    ck_tile::index_t hdim_q   = 0;
    ck_tile::index_t nhead_q  = 0;
    ck_tile::index_t nhead_k  = 0; // must divide nhead_q (asserted when kernel args are built)

    // Element strides of Q and K: between sequence positions, between heads, between batches.
    ck_tile::index_t stride_q       = 0;
    ck_tile::index_t stride_k       = 0;
    ck_tile::index_t nhead_stride_q = 0;
    ck_tile::index_t nhead_stride_k = 0;
    ck_tile::index_t batch_stride_q = 0;
    ck_tile::index_t batch_stride_k = 0;

    // Sparge selection parameters and softmax scale, passed through to the kernel unchanged.
    float simthreshd1 = 0.f;
    float cdfthreshd  = 0.f;
    float topk        = 0.f;
    float scale       = 0.f;

    // Device pointers to the outputs: block map, LUT and per-Q-block valid-block count.
    void* block_map_ptr       = nullptr;
    void* lut_ptr             = nullptr;
    void* valid_block_num_ptr = nullptr;
};
/// Runtime selector used by sparge_blockmap_fwd() to pick a kernel instance.
struct sparge_blockmap_traits
{
    std::string data_type; ///< Element type name, e.g. "fp16" or "bf16".
    int hdim_q = 0;        ///< Head dimension of Q/K; 0 until set by the caller.
};
// ============================================================================
// Create kernel args and grid dimensions
// ============================================================================
/// @brief Translate host-side arguments into the kernel's argument struct and launch grid.
///
/// @tparam BlockMapKernel  Kernel type providing static MakeKargs() and GridSize().
/// @param args             Host-side arguments; nhead_q must be a multiple of nhead_k.
/// @return A ck_tile tuple of (kernel args, grid dimensions).
template <typename BlockMapKernel>
auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args)
{
    assert(args.nhead_q % args.nhead_k == 0);

    // Number of query heads per key head; this is what the kernel receives instead of nhead_k.
    const ck_tile::index_t nhead_ratio_qk = args.nhead_q / args.nhead_k;

    auto kernel_args = BlockMapKernel::MakeKargs(args.q_ptr,
                                                 args.k_ptr,
                                                 args.seqlen_q,
                                                 args.seqlen_k,
                                                 args.hdim_q,
                                                 args.nhead_q,
                                                 nhead_ratio_qk,
                                                 args.stride_q,
                                                 args.stride_k,
                                                 args.nhead_stride_q,
                                                 args.nhead_stride_k,
                                                 args.batch_stride_q,
                                                 args.batch_stride_k,
                                                 args.simthreshd1,
                                                 args.cdfthreshd,
                                                 args.topk,
                                                 args.scale,
                                                 args.block_map_ptr,
                                                 args.lut_ptr,
                                                 args.valid_block_num_ptr);

    dim3 grid_dim = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q);

    return ck_tile::make_tuple(kernel_args, grid_dim);
}
// ============================================================================
// Hand-written template instantiation dispatch
// ============================================================================
/// @brief Launch the block-map kernel instance that matches @p traits.
///
/// Defined in sparge_blockmap_inst.cpp.
///
/// @param traits         Requested data type name and head dimension.
/// @param args           Host-side kernel arguments.
/// @param stream_config  Stream configuration used for the launch and for log verbosity.
/// @return The result of ck_tile::launch_kernel, or -1.f if no instance matches @p traits.
float sparge_blockmap_fwd(sparge_blockmap_traits traits,
sparge_blockmap_args args,
const ck_tile::stream_config& stream_config);

View File

@@ -17,6 +17,7 @@
#include "ck_tile/core/utility/bit_cast.hpp"
#include "vsa_sparge_attention.h"
#include "sparge_blockmap.h"
#include "sparge_tool.hpp"
// ============================================================================
@@ -198,53 +199,37 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<T> output_host =
o_perm ? ck_tile::HostTensor<T>({batch, nhead, seqlen_q, hdim_v})
: ck_tile::HostTensor<T>({batch, seqlen_q, nhead, hdim_v});
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
std::cout << "\nInitializing tensors..." << std::endl;
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed}(q_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 1}(k_host);
ck_tile::FillUniformDistribution<T>{-0.5f, 0.5f, seed + 2}(v_host);
// Build block map using Sparge tool
std::cout << "Building Sparge block map..." << std::endl;
sparge::SpargeParams p;
p.BLKQ = static_cast<int>(BLKQ);
p.BLKK = static_cast<int>(BLKK);
p.simthreshd1 = simthreshd1;
p.cdfthreshd = cdfthreshd;
p.topk = topk;
p.i_perm = i_perm;
ck_tile::HostTensor<uint8_t> block_relation_onehot =
sparge::build_block_map_meansim(q_host, k_host, p);
// Convert to VSA LUT (delta-encoded) + valid_block_num
std::cout << "Converting block map to VSA LUT (delta)..." << std::endl;
auto vsa_lut = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
// Print actual sparsity (based on one-hot)
std::size_t total_blocks = 0;
std::size_t active_blocks = 0;
for(ck_tile::index_t b = 0; b < batch; ++b)
{
for(ck_tile::index_t h = 0; h < nhead; ++h)
{
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
{
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
{
total_blocks++;
if(block_relation_onehot(b, h, qb, kb) != 0)
active_blocks++;
}
}
}
}
float actual_sparsity =
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
std::cout << " Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
<< total_blocks << " blocks active)" << std::endl;
// ==================================================================
// GPU: Build block map + VSA LUT in one kernel (always run)
// ==================================================================
std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl;
ck_tile::HostTensor<uint8_t> block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks});
auto vsa_lut_gpu = sparge_blockmap_gpu<T>(q_host,
k_host,
block_map_gpu,
batch,
nhead,
nhead_k,
seqlen_q,
seqlen_k,
hdim_q,
i_perm,
simthreshd1,
cdfthreshd,
topk,
static_cast<int>(BLKQ),
static_cast<int>(BLKK),
0);
// ==================================================================
// VSA sparse attention kernel (always run)
// ==================================================================
std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl;
try
@@ -254,8 +239,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
vsa_sparge_attention<T>(q_host,
k_host,
v_host,
vsa_lut.lut,
vsa_lut.valid_block_num,
vsa_lut_gpu.lut,
vsa_lut_gpu.valid_block_num,
output_host,
batch,
nhead,
@@ -276,8 +261,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
vsa_sparge_attention<T>(q_host,
k_host,
v_host,
vsa_lut.lut,
vsa_lut.valid_block_num,
vsa_lut_gpu.lut,
vsa_lut_gpu.valid_block_num,
output_host,
batch,
nhead,
@@ -301,8 +286,8 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
vsa_sparge_attention<T>(q_host,
k_host,
v_host,
vsa_lut.lut,
vsa_lut.valid_block_num,
vsa_lut_gpu.lut,
vsa_lut_gpu.valid_block_num,
output_host,
batch,
nhead,
@@ -332,17 +317,168 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
return false;
}
// ==================================================================
// Sparsity statistics (always run, pure CPU read of HostTensor)
// ==================================================================
std::size_t total_blocks = 0;
std::size_t active_blocks = 0;
for(ck_tile::index_t b = 0; b < batch; ++b)
{
for(ck_tile::index_t h = 0; h < nhead; ++h)
{
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
{
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
{
total_blocks++;
if(block_map_gpu(b, h, qb, kb) != 0)
active_blocks++;
}
}
}
}
float actual_sparsity =
1.0f - static_cast<float>(active_blocks) / static_cast<float>(total_blocks);
std::cout << "\n Actual sparsity: " << actual_sparsity << " (" << active_blocks << "/"
<< total_blocks << " blocks active)" << std::endl;
// ==================================================================
// Validation (only when -v=1)
// ==================================================================
bool pass = true;
if(do_validation)
{
std::cout << "\n--- Performing CPU validation ---" << std::endl;
// CPU golden: block map + VSA LUT
std::cout << "Building Sparge block map (CPU golden)..." << std::endl;
sparge::SpargeParams p;
p.BLKQ = static_cast<int>(BLKQ);
p.BLKK = static_cast<int>(BLKK);
p.simthreshd1 = simthreshd1;
p.cdfthreshd = cdfthreshd;
p.topk = topk;
p.i_perm = i_perm;
ck_tile::HostTensor<uint8_t> block_relation_onehot =
sparge::build_block_map_meansim(q_host, k_host, p);
std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl;
auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot);
// Validate block map
std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl;
{
std::size_t bmap_mismatches = 0;
for(ck_tile::index_t b = 0; b < batch; ++b)
{
for(ck_tile::index_t h = 0; h < nhead; ++h)
{
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
{
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
{
if(block_map_gpu(b, h, qb, kb) !=
block_relation_onehot(b, h, qb, kb))
{
bmap_mismatches++;
if(bmap_mismatches <= 10)
{
std::cout
<< " block_map mismatch at [" << b << "," << h << ","
<< qb << "," << kb
<< "]: GPU="
<< static_cast<int>(block_map_gpu(b, h, qb, kb))
<< " CPU="
<< static_cast<int>(
block_relation_onehot(b, h, qb, kb))
<< std::endl;
}
}
}
}
}
}
std::cout << " Block map mismatches: " << bmap_mismatches << " / "
<< (batch * nhead * num_q_blocks * num_k_blocks) << std::endl;
if(bmap_mismatches > 0)
{
std::cout << ">>> GPU BLOCK MAP VALIDATION FAILED <<<" << std::endl;
pass = false;
}
else
{
std::cout << ">>> GPU BLOCK MAP VALIDATION PASSED <<<" << std::endl;
}
}
// Validate VSA LUT
std::cout << "\n--- Validating GPU VSA LUT vs CPU golden ---" << std::endl;
{
std::size_t lut_mismatches = 0;
std::size_t valid_mismatches = 0;
for(ck_tile::index_t b = 0; b < batch; ++b)
{
for(ck_tile::index_t h = 0; h < nhead; ++h)
{
for(ck_tile::index_t qb = 0; qb < num_q_blocks; ++qb)
{
if(vsa_lut_gpu.valid_block_num(b, h, qb) !=
vsa_lut_cpu.valid_block_num(b, h, qb))
{
valid_mismatches++;
if(valid_mismatches <= 5)
{
std::cout
<< " valid_block_num mismatch at [" << b << "," << h
<< "," << qb
<< "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb)
<< " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb)
<< std::endl;
}
}
for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb)
{
if(vsa_lut_gpu.lut(b, h, qb, kb) !=
vsa_lut_cpu.lut(b, h, qb, kb))
{
lut_mismatches++;
if(lut_mismatches <= 10)
{
std::cout
<< " LUT mismatch at [" << b << "," << h << "," << qb
<< "," << kb
<< "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb)
<< " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb)
<< std::endl;
}
}
}
}
}
}
std::cout << " LUT mismatches: " << lut_mismatches << std::endl;
std::cout << " valid_block_num mismatches: " << valid_mismatches << std::endl;
if(lut_mismatches == 0 && valid_mismatches == 0)
{
std::cout << ">>> GPU VSA LUT VALIDATION PASSED <<<" << std::endl;
}
else
{
std::cout << ">>> GPU VSA LUT VALIDATION FAILED <<<" << std::endl;
pass = false;
}
}
// Validate attention output
float scale = 1.0f / std::sqrt(static_cast<float>(hdim_q));
std::cout << "Computing reference output..." << std::endl;
std::cout << "\nComputing reference attention output..." << std::endl;
auto q_ref = to_bhsd(q_host, i_perm);
auto k_ref = to_bhsd(k_host, i_perm);
auto v_ref = to_bhsd(v_host, i_perm);
ck_tile::HostTensor<T> output_ref({batch, nhead, seqlen_q, hdim_v});
ck_tile::reference_blocked_attention<T, uint8_t>(
q_ref, k_ref, v_ref, block_relation_onehot, output_ref, BLKQ, BLKK, scale);
@@ -374,7 +510,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
}
}
std::cout << "\nValidation results:" << std::endl;
std::cout << "\nAttention validation results:" << std::endl;
std::cout << " Max absolute difference: " << max_diff << std::endl;
std::cout << " Max relative difference: " << max_rel_diff << std::endl;
std::cout << " Number of mismatches: " << num_errors << " / "