diff --git a/example/ck_tile/50_sparse_attn/CMakeLists.txt b/example/ck_tile/50_sparse_attn/CMakeLists.txt index 169ed87ac3..f234f631b6 100644 --- a/example/ck_tile/50_sparse_attn/CMakeLists.txt +++ b/example/ck_tile/50_sparse_attn/CMakeLists.txt @@ -249,14 +249,12 @@ set(SPARGE_VSA_INSTANCES "tile_sparge_vsa_instances") add_library(${SPARGE_VSA_INSTANCES} OBJECT EXCLUDE_FROM_ALL ${SPARGE_VSA_GEN_BLOBS} - ${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp ) target_include_directories(${SPARGE_VSA_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR} ${PROJECT_SOURCE_DIR}/include/ck_tile/ops/sparse_attn ) set_source_files_properties(${SPARGE_VSA_GEN_BLOBS} PROPERTIES LANGUAGE HIP) -set_source_files_properties(${CMAKE_CURRENT_LIST_DIR}/vsa_sparge_attention.cpp PROPERTIES LANGUAGE HIP) set_property(TARGET ${SPARGE_VSA_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) target_compile_options(${SPARGE_VSA_INSTANCES} PRIVATE @@ -273,7 +271,6 @@ set(SPARGE_BLOCKMAP_INSTANCES "tile_sparge_blockmap_instances") add_library(${SPARGE_BLOCKMAP_INSTANCES} OBJECT EXCLUDE_FROM_ALL ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp - ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp ) target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR} @@ -281,7 +278,6 @@ target_include_directories(${SPARGE_BLOCKMAP_INSTANCES} PRIVATE ) set_source_files_properties( ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap_inst.cpp - ${CMAKE_CURRENT_LIST_DIR}/sparge_blockmap.cpp PROPERTIES LANGUAGE HIP ) set_property(TARGET ${SPARGE_BLOCKMAP_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS}) diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp deleted file mode 100644 index b9ac56c533..0000000000 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap.cpp +++ /dev/null @@ -1,156 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#include "sparge_blockmap.h" -#include "sparge_blockmap_trek.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" -#include "ck_tile/host/device_memory.hpp" -#include -#include - -template -sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - ck_tile::HostTensor& block_map_out, - int batch, - int nhead_q, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - bool i_perm, - float simthreshd1, - float cdfthreshd, - float topk, - int blkq, - int blkk, - int log_level) -{ - static_assert(std::is_same_v || - std::is_same_v, - "sparge_blockmap_gpu supports fp16/bf16 only."); - - std::string data_type = "fp16"; - if constexpr(std::is_same_v) - { - data_type = "bf16"; - } - - const ck_tile::index_t num_q_blocks = ck_tile::integer_divide_ceil(seqlen_q, blkq); - const ck_tile::index_t num_k_blocks = ck_tile::integer_divide_ceil(seqlen_k, blkk); - - const float scale = 1.0f / std::sqrt(static_cast(hdim_q)); - - // Allocate device memory - ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); - - const std::size_t bmap_bytes = - static_cast(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(uint8_t); - const std::size_t lut_bytes = - static_cast(batch) * nhead_q * num_q_blocks * num_k_blocks * sizeof(int32_t); - const std::size_t valid_bytes = - static_cast(batch) * nhead_q * num_q_blocks * sizeof(int32_t); - - ck_tile::DeviceMem bmap_buf(bmap_bytes); - ck_tile::DeviceMem lut_buf(lut_bytes); - ck_tile::DeviceMem valid_buf(valid_bytes); - - q_buf.ToDevice(TQ.data()); - k_buf.ToDevice(TK.data()); - bmap_buf.SetZero(); - lut_buf.SetZero(); - valid_buf.SetZero(); - - // Compute strides (assumes BHSD if i_perm, BSHD otherwise) - const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead_q * hdim_q; - const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q; - const ck_tile::index_t nhead_stride_q = - i_perm ? static_cast(seqlen_q) * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_k = - i_perm ? static_cast(seqlen_k) * hdim_q : hdim_q; - const ck_tile::index_t batch_stride_q = - static_cast(nhead_q) * seqlen_q * hdim_q; - const ck_tile::index_t batch_stride_k = - static_cast(nhead_k) * seqlen_k * hdim_q; - - ck_tile::stream_config stream_config{nullptr, false, log_level, 0, 1, false}; - - sparge_blockmap_args args; - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.batch = batch; - args.seqlen_q = seqlen_q; - args.seqlen_k = seqlen_k; - args.hdim_q = hdim_q; - args.nhead_q = nhead_q; - args.nhead_k = nhead_k; - args.stride_q = stride_q; - args.stride_k = stride_k; - args.nhead_stride_q = nhead_stride_q; - args.nhead_stride_k = nhead_stride_k; - args.batch_stride_q = batch_stride_q; - args.batch_stride_k = batch_stride_k; - args.simthreshd1 = simthreshd1; - args.cdfthreshd = cdfthreshd; - args.topk = topk; - args.scale = scale; - args.block_map_ptr = bmap_buf.GetDeviceBuffer(); - args.lut_ptr = lut_buf.GetDeviceBuffer(); - args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); - - sparge_blockmap_traits traits; - traits.data_type = data_type; - traits.hdim_q = hdim_q; - - sparge_blockmap_fwd(traits, args, stream_config); - - // Copy results back to host - bmap_buf.FromDevice(block_map_out.data(), bmap_bytes); - - sparge::VSALut vsa_lut{ - ck_tile::HostTensor({batch, nhead_q, num_q_blocks, num_k_blocks}), - ck_tile::HostTensor({batch, nhead_q, num_q_blocks}), - }; - lut_buf.FromDevice(vsa_lut.lut.data(), lut_bytes); - valid_buf.FromDevice(vsa_lut.valid_block_num.data(), valid_bytes); - - return vsa_lut; -} - -// Explicit template instantiations -template sparge::VSALut -sparge_blockmap_gpu(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - bool, - float, - float, - float, - int, - int, - int); - -template sparge::VSALut -sparge_blockmap_gpu(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - bool, - float, - float, - float, - int, - int, - int); diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap.h b/example/ck_tile/50_sparse_attn/sparge_blockmap.h deleted file mode 100644 index 3057257ca1..0000000000 --- a/example/ck_tile/50_sparse_attn/sparge_blockmap.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#pragma once - -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" -#include "sparge_tool.hpp" - -template -sparge::VSALut sparge_blockmap_gpu(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - ck_tile::HostTensor& block_map_out, - int batch, - int nhead_q, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - bool i_perm, - float simthreshd1, - float cdfthreshd, - float topk, - int blkq, - int blkk, - int log_level = 0); diff --git a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp index 638a867b0f..572b708f9e 100644 --- a/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp +++ b/example/ck_tile/50_sparse_attn/test_sparge_vsa_sparse_attn.cpp @@ -1,23 +1,17 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT -// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention +// Demo: Sparge block-map -> (delta LUT) -> VSA sparse attention (all-in-device) #include -#include #include -#include #include -#include -#include -#include - #include "ck_tile/host.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host/reference/reference_blocked_attention.hpp" #include "ck_tile/core/utility/bit_cast.hpp" -#include "vsa_sparge_attention.h" -#include "sparge_blockmap.h" +#include "sparge_blockmap_trek.hpp" +#include "fmha_fwd_trek.hpp" #include "sparge_tool.hpp" // ============================================================================ @@ -192,7 +186,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser) << ", topk=" << topk << ")" << std::endl; std::cout << " i_perm: " << i_perm << ", o_perm: " << o_perm << std::endl; - // Create host tensors + // Create host tensors and fill with random data ck_tile::HostTensor q_host = make_qkv_tensor(batch, nhead, seqlen_q, hdim_q, i_perm); ck_tile::HostTensor k_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_q, i_perm); ck_tile::HostTensor v_host = make_qkv_tensor(batch, nhead_k, seqlen_k, hdim_v, i_perm); @@ -206,119 +200,157 @@ bool run_test(const ck_tile::ArgParser& arg_parser) ck_tile::FillUniformDistribution{-0.5f, 0.5f, seed + 2}(v_host); // ================================================================== - // GPU: Build block map + VSA LUT in one kernel (always run) + // Allocate device memory once, HtoD once // ================================================================== - std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl; - ck_tile::HostTensor block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks}); - auto vsa_lut_gpu = sparge_blockmap_gpu(q_host, - k_host, - block_map_gpu, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - i_perm, - simthreshd1, - cdfthreshd, - topk, - static_cast(BLKQ), - static_cast(BLKK), - 0); + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(output_host.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.data()); + k_buf.ToDevice(k_host.data()); + v_buf.ToDevice(v_host.data()); + + const std::size_t bmap_bytes = + static_cast(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(uint8_t); + const std::size_t lut_bytes = + static_cast(batch) * nhead * num_q_blocks * num_k_blocks * sizeof(int32_t); + const std::size_t valid_bytes = + static_cast(batch) * nhead * num_q_blocks * sizeof(int32_t); + + ck_tile::DeviceMem bmap_buf(bmap_bytes); + ck_tile::DeviceMem lut_buf(lut_bytes); + ck_tile::DeviceMem valid_buf(valid_bytes); + bmap_buf.SetZero(); + lut_buf.SetZero(); + valid_buf.SetZero(); // ================================================================== - // VSA sparse attention kernel (always run) + // Common stride calculations + // ================================================================== + assert(nhead % nhead_k == 0); + const float scale_s = 1.0f / std::sqrt(static_cast(hdim_q)); + + const ck_tile::index_t stride_q = i_perm ? hdim_q : nhead * hdim_q; + const ck_tile::index_t stride_k = i_perm ? hdim_q : nhead_k * hdim_q; + const ck_tile::index_t stride_v = i_perm ? hdim_v : nhead_k * hdim_v; + const ck_tile::index_t stride_o = o_perm ? hdim_v : nhead * hdim_v; + const ck_tile::index_t nhead_stride_q = i_perm ? seqlen_q * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_k = i_perm ? seqlen_k * hdim_q : hdim_q; + const ck_tile::index_t nhead_stride_v = i_perm ? seqlen_k * hdim_v : hdim_v; + const ck_tile::index_t nhead_stride_o = o_perm ? seqlen_q * hdim_v : hdim_v; + const ck_tile::index_t batch_stride_q = nhead * seqlen_q * hdim_q; + const ck_tile::index_t batch_stride_k = nhead_k * seqlen_k * hdim_q; + const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * seqlen_k; + const ck_tile::index_t batch_stride_o = nhead * seqlen_q * hdim_v; + + std::string data_type = "fp16"; + if constexpr(std::is_same_v) + data_type = "bf16"; + + std::string msk_str = "0"; + mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); + + // ================================================================== + // GPU: Build block map + VSA LUT (always run, device-only) + // ================================================================== + std::cout << "Building Sparge block map + VSA LUT (GPU)..." << std::endl; + { + sparge_blockmap_args args; + args.q_ptr = q_buf.GetDeviceBuffer(); + args.k_ptr = k_buf.GetDeviceBuffer(); + args.batch = batch; + args.seqlen_q = seqlen_q; + args.seqlen_k = seqlen_k; + args.hdim_q = hdim_q; + args.nhead_q = nhead; + args.nhead_k = nhead_k; + args.stride_q = stride_q; + args.stride_k = stride_k; + args.nhead_stride_q = nhead_stride_q; + args.nhead_stride_k = nhead_stride_k; + args.batch_stride_q = batch_stride_q; + args.batch_stride_k = batch_stride_k; + args.simthreshd1 = simthreshd1; + args.cdfthreshd = cdfthreshd; + args.topk = topk; + args.scale = scale_s; + args.block_map_ptr = bmap_buf.GetDeviceBuffer(); + args.lut_ptr = lut_buf.GetDeviceBuffer(); + args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); + + sparge_blockmap_traits traits; + traits.data_type = data_type; + traits.hdim_q = hdim_q; + + sparge_blockmap_fwd(traits, args, ck_tile::stream_config{}); + } + + // ================================================================== + // VSA sparse attention kernel (always run, LUT stays on device) // ================================================================== std::cout << "\n--- Running VSA sparse attention kernel ---" << std::endl; - try - { - if(kname) - { - vsa_sparge_attention(q_host, - k_host, - v_host, - vsa_lut_gpu.lut, - vsa_lut_gpu.valid_block_num, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 1); - } + fmha_vsa_fwd_args fmha_args; + fmha_args.q_ptr = q_buf.GetDeviceBuffer(); + fmha_args.k_ptr = k_buf.GetDeviceBuffer(); + fmha_args.v_ptr = v_buf.GetDeviceBuffer(); + fmha_args.lut_ptr = lut_buf.GetDeviceBuffer(); + fmha_args.valid_block_num_ptr = valid_buf.GetDeviceBuffer(); + fmha_args.o_ptr = o_buf.GetDeviceBuffer(); + fmha_args.batch = batch; + fmha_args.seqlen_q = seqlen_q; + fmha_args.seqlen_k = seqlen_k; + fmha_args.max_seqlen_q = seqlen_q; + fmha_args.hdim_q = hdim_q; + fmha_args.hdim_v = hdim_v; + fmha_args.nhead_q = nhead; + fmha_args.nhead_k = nhead_k; + fmha_args.scale_s = scale_s; + fmha_args.stride_q = stride_q; + fmha_args.stride_k = stride_k; + fmha_args.stride_v = stride_v; + fmha_args.stride_o = stride_o; + fmha_args.nhead_stride_q = nhead_stride_q; + fmha_args.nhead_stride_k = nhead_stride_k; + fmha_args.nhead_stride_v = nhead_stride_v; + fmha_args.nhead_stride_o = nhead_stride_o; + fmha_args.batch_stride_q = batch_stride_q; + fmha_args.batch_stride_k = batch_stride_k; + fmha_args.batch_stride_v = batch_stride_v; + fmha_args.batch_stride_o = batch_stride_o; + fmha_args.window_size_left = mask.left; + fmha_args.window_size_right = mask.right; + fmha_args.mask_type = static_cast(mask.type); - for(int i = 0; i < warmup; ++i) - { - vsa_sparge_attention(q_host, - k_host, - v_host, - vsa_lut_gpu.lut, - vsa_lut_gpu.valid_block_num, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 0); - } + fmha_vsa_fwd_traits fmha_traits; + fmha_traits.hdim_q = hdim_q; + fmha_traits.hdim_v = hdim_v; + fmha_traits.data_type = data_type; + fmha_traits.is_v_rowmajor = true; + fmha_traits.mask_type = mask.type; - [[maybe_unused]] auto sync_status1 = hipDeviceSynchronize(); - auto start = std::chrono::high_resolution_clock::now(); + ck_tile::stream_config stream_config{nullptr, + true, + /* log_level = */ kname ? 1 : 0, + warmup, + repeat, + false}; - for(int i = 0; i < repeat; ++i) - { - vsa_sparge_attention(q_host, - k_host, - v_host, - vsa_lut_gpu.lut, - vsa_lut_gpu.valid_block_num, - output_host, - batch, - nhead, - nhead_k, - seqlen_q, - seqlen_k, - hdim_q, - hdim_v, - i_perm, - o_perm, - seqlen_q, - seqlen_k, - 0); - } + float avg_time_ms = sparge_vsa_fwd(fmha_traits, fmha_args, stream_config); - [[maybe_unused]] auto sync_status2 = hipDeviceSynchronize(); - auto end = std::chrono::high_resolution_clock::now(); - double avg_time_ms = - std::chrono::duration(end - start).count() / repeat; + std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" + << std::endl; - std::cout << "\n>>>> VSA sparse attention average time: " << avg_time_ms << " ms <<<<" - << std::endl; - } - catch(const std::exception& e) - { - std::cerr << "Error during kernel execution: " << e.what() << std::endl; - return false; - } + // DtoH: attention output (always needed) + o_buf.FromDevice(output_host.data(), output_host.get_element_space_size_in_bytes()); + + // DtoH: block_map (needed for sparsity stats and validation) + ck_tile::HostTensor block_map_gpu({batch, nhead, num_q_blocks, num_k_blocks}); + bmap_buf.FromDevice(block_map_gpu.data(), bmap_bytes); // ================================================================== - // Sparsity statistics (always run, pure CPU read of HostTensor) + // Sparsity statistics (pure CPU, reads block_map HostTensor) // ================================================================== std::size_t total_blocks = 0; std::size_t active_blocks = 0; @@ -366,6 +398,14 @@ bool run_test(const ck_tile::ArgParser& arg_parser) std::cout << "Converting block map to VSA LUT (delta, CPU)..." << std::endl; auto vsa_lut_cpu = sparge::block_map_to_vsa_lut_delta(block_relation_onehot); + // DtoH: LUT + valid_block_num (only for validation) + sparge::VSALut vsa_lut_gpu{ + ck_tile::HostTensor({batch, nhead, num_q_blocks, num_k_blocks}), + ck_tile::HostTensor({batch, nhead, num_q_blocks}), + }; + lut_buf.FromDevice(vsa_lut_gpu.lut.data(), lut_bytes); + valid_buf.FromDevice(vsa_lut_gpu.valid_block_num.data(), valid_bytes); + // Validate block map std::cout << "\n--- Validating GPU block map vs CPU golden ---" << std::endl; { @@ -378,20 +418,16 @@ bool run_test(const ck_tile::ArgParser& arg_parser) { for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) { - if(block_map_gpu(b, h, qb, kb) != - block_relation_onehot(b, h, qb, kb)) + if(block_map_gpu(b, h, qb, kb) != block_relation_onehot(b, h, qb, kb)) { bmap_mismatches++; if(bmap_mismatches <= 10) { std::cout - << " block_map mismatch at [" << b << "," << h << "," - << qb << "," << kb - << "]: GPU=" - << static_cast(block_map_gpu(b, h, qb, kb)) - << " CPU=" - << static_cast( - block_relation_onehot(b, h, qb, kb)) + << " block_map mismatch at [" << b << "," << h << "," << qb + << "," << kb << "]: GPU=" + << static_cast(block_map_gpu(b, h, qb, kb)) << " CPU=" + << static_cast(block_relation_onehot(b, h, qb, kb)) << std::endl; } } @@ -429,28 +465,24 @@ bool run_test(const ck_tile::ArgParser& arg_parser) valid_mismatches++; if(valid_mismatches <= 5) { - std::cout - << " valid_block_num mismatch at [" << b << "," << h - << "," << qb - << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb) - << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb) - << std::endl; + std::cout << " valid_block_num mismatch at [" << b << "," << h + << "," << qb + << "]: GPU=" << vsa_lut_gpu.valid_block_num(b, h, qb) + << " CPU=" << vsa_lut_cpu.valid_block_num(b, h, qb) + << std::endl; } } for(ck_tile::index_t kb = 0; kb < num_k_blocks; ++kb) { - if(vsa_lut_gpu.lut(b, h, qb, kb) != - vsa_lut_cpu.lut(b, h, qb, kb)) + if(vsa_lut_gpu.lut(b, h, qb, kb) != vsa_lut_cpu.lut(b, h, qb, kb)) { lut_mismatches++; if(lut_mismatches <= 10) { std::cout << " LUT mismatch at [" << b << "," << h << "," << qb - << "," << kb - << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb) - << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) - << std::endl; + << "," << kb << "]: GPU=" << vsa_lut_gpu.lut(b, h, qb, kb) + << " CPU=" << vsa_lut_cpu.lut(b, h, qb, kb) << std::endl; } } } diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp deleted file mode 100644 index 5f9c2676dd..0000000000 --- a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.cpp +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#include "vsa_sparge_attention.h" -#include "fmha_fwd_trek.hpp" -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" -#include "ck_tile/host/device_memory.hpp" -#include - -template -ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - const ck_tile::HostTensor& TV, - const ck_tile::HostTensor& TKV_block_idx, - const ck_tile::HostTensor& TKV_blocks, - ck_tile::HostTensor& Y, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k, - int log_level) -{ - static_assert(std::is_same_v || - std::is_same_v, - "VSA sparse attention supports fp16/bf16 only."); - std::string data_type = "fp16"; - if constexpr(std::is_same_v) - { - data_type = "bf16"; - } - - if(max_seqlen_q == 0) - max_seqlen_q = seqlen_q; - if(max_seqlen_k == 0) - max_seqlen_k = seqlen_k; - bool is_v_rowmajor = true; - float scale_s = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - std::string msk_str = "0"; - mask_info mask = mask_info::decode(msk_str, seqlen_q, seqlen_k); - - const ck_tile::index_t shape_seqlen_q = seqlen_q; - const ck_tile::index_t shape_seqlen_k = seqlen_k; - - ck_tile::stream_config stream_config{nullptr, - false, // time_kernel - log_level, - 0, - 1, - false}; - - ck_tile::DeviceMem q_buf(TQ.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(TK.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(TV.get_element_space_size_in_bytes()); - ck_tile::DeviceMem lut_buf(TKV_block_idx.get_element_space_size_in_bytes()); - ck_tile::DeviceMem valid_block_num_buf(TKV_blocks.get_element_space_size_in_bytes()); - ck_tile::DeviceMem o_buf(Y.get_element_space_size_in_bytes()); - - q_buf.ToDevice(TQ.data()); - k_buf.ToDevice(TK.data()); - v_buf.ToDevice(TV.data()); - lut_buf.ToDevice(TKV_block_idx.data()); - valid_block_num_buf.ToDevice(TKV_blocks.data()); - - const auto init_args = [&](auto& args) { - assert(nhead % nhead_k == 0); - const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); - const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); - const ck_tile::index_t stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? hdim_v : nhead_k * hdim_v; - else - return (i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k); - }(); - const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); - const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); - const ck_tile::index_t nhead_stride_k = i_perm ? shape_seqlen_k * hdim_q : hdim_q; - const ck_tile::index_t nhead_stride_v = [&]() { - if(is_v_rowmajor) - return i_perm ? shape_seqlen_k * hdim_v : hdim_v; - else - return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k; - }(); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); - const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); - const ck_tile::index_t batch_stride_k = nhead_k * shape_seqlen_k * hdim_q; - const ck_tile::index_t batch_stride_v = nhead_k * hdim_v * shape_seqlen_k; - const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); - - args.q_ptr = q_buf.GetDeviceBuffer(); - args.k_ptr = k_buf.GetDeviceBuffer(); - args.v_ptr = v_buf.GetDeviceBuffer(); - args.lut_ptr = lut_buf.GetDeviceBuffer(); - args.valid_block_num_ptr = valid_block_num_buf.GetDeviceBuffer(); - - args.batch = batch; - args.seqlen_q = shape_seqlen_q; - args.hdim_q = hdim_q; - args.hdim_v = hdim_v; - args.nhead_q = nhead; - args.nhead_k = nhead_k; - - args.stride_q = stride_q; - args.stride_k = stride_k; - args.stride_v = stride_v; - args.nhead_stride_q = nhead_stride_q; - args.nhead_stride_k = nhead_stride_k; - args.nhead_stride_v = nhead_stride_v; - args.batch_stride_q = batch_stride_q; - args.batch_stride_k = batch_stride_k; - args.batch_stride_v = batch_stride_v; - - args.o_ptr = o_buf.GetDeviceBuffer(); - - args.seqlen_k = shape_seqlen_k; - args.max_seqlen_q = max_seqlen_q; - - args.scale_s = scale_s; - - args.stride_o = stride_o; - args.nhead_stride_o = nhead_stride_o; - args.batch_stride_o = batch_stride_o; - - args.window_size_left = mask.left; - args.window_size_right = mask.right; - args.mask_type = static_cast(mask.type); - }; - - const auto init_traits = [&](auto& traits) { - traits.hdim_q = hdim_q; - traits.hdim_v = hdim_v; - traits.data_type = data_type; - traits.is_v_rowmajor = is_v_rowmajor; - traits.mask_type = mask.type; - }; - - fmha_vsa_fwd_traits fmha_traits; - init_traits(fmha_traits); - - fmha_vsa_fwd_args args; - init_args(args); - - sparge_vsa_fwd(fmha_traits, args, stream_config); - - o_buf.FromDevice(Y.data(), Y.get_element_space_size_in_bytes()); - - return Y; -} - -template ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - int, - bool, - bool, - int, - int, - int); - -template ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - const ck_tile::HostTensor&, - ck_tile::HostTensor&, - int, - int, - int, - int, - int, - int, - int, - bool, - bool, - int, - int, - int); diff --git a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h b/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h deleted file mode 100644 index d51a7e8c00..0000000000 --- a/example/ck_tile/50_sparse_attn/vsa_sparge_attention.h +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT -#pragma once -#include -#include -#include "ck_tile/core.hpp" -#include "ck_tile/host/host_tensor.hpp" - -template -ck_tile::HostTensor -vsa_sparge_attention(const ck_tile::HostTensor& TQ, - const ck_tile::HostTensor& TK, - const ck_tile::HostTensor& TV, - const ck_tile::HostTensor& TKV_block_idx, - const ck_tile::HostTensor& TKV_blocks, - ck_tile::HostTensor& Y, - int batch, - int nhead, - int nhead_k, - int seqlen_q, - int seqlen_k, - int hdim_q, - int hdim_v, - bool i_perm, - bool o_perm, - int max_seqlen_q, - int max_seqlen_k, - int log_level = 0);