More compilation fixes

This commit is contained in:
Tianxing Wu
2025-10-20 15:53:35 +00:00
parent d68a541c19
commit f72b994b00
9 changed files with 89 additions and 115 deletions

View File

@@ -178,6 +178,15 @@
# --- Unified Attention target (kept) ---
#
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Currently only gfx9 archs are supported by FMHA
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
if(NOT INST_TARGETS)
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()
set(EXAMPLE_FMHA_FWD_V3 "tile_example_unified_attention")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")

View File

@@ -154,7 +154,6 @@ struct Problem
float scale_k;
float scale_v;
mask_info mask;
TensorLayout output_layout;
std::vector<int> query_lens;
std::vector<int> kv_lens;
};
@@ -350,8 +349,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
args.stride_v_cache_3 = args.stride_k_cache_3;
args.o_ptr = o_buf.GetDeviceBuffer();
args.output_stride_0 = query_stride_0;
args.output_stride_1 = query_stride_1;
args.output_stride_0 = args.query_stride_0;
args.output_stride_1 = args.query_stride_1;
// Optional cumulative seqlen overrides (exclude PAD)
auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
@@ -386,19 +385,19 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
};
calculate_cumulative(eff_query_lens, cu_query_lens);
ck_tile::DeviceMem seq_lens_buf(kv_lens.size());
ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size());
ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size());
seq_lens_buf.ToDevice(kv_lens.data());
seq_lens_buf.ToDevice(eff_kv_lens.data());
query_start_len_buf.ToDevice(cu_query_lens.data());
args.seq_lens_ptr =reinterpret_cast<const ck_tile::index_t*>(seq_lens_buf.GetDeviceBuffer());
args.query_start_len_ptr =reinterpret_cast<const ck_tile::index_t*>(query_start_len_buf.GetDeviceBuffer());
auto max_kv_len = std::max_element(problem.kv_lens.begin(), problem.kv_lens.end());
int max_kv_len = std::max_element(eff_kv_lens.begin(), eff_kv_lens.end());
index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE
ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE;
// Create block_tables
ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * sizeof(ck_tile::index_t));
@@ -433,30 +432,24 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
return false;
}
std::size_t flop = [&] {
if(problem.mask.type == mask_enum::no_mask)
{
return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
else
{
/// FIXME: Use a more accurate method; for now, we're just dividing the flop by 2.
return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
}();
// std::size_t flop = [&] {
// if(problem.mask.type == mask_enum::no_mask)
// {
// return 4 * args.num_tokens * problem.nhead_q *
// problem.hdim;
// }
// else
// {
// /// FIXME: Use a more accurate method; for now, we're just dividing the flop by 2.
// return 2 * args.num_tokens * problem.nhead_q *
// problem.hdim;
// }
// }();
// TODO fix this
std::size_t flop = 1;
float tflops = static_cast<float>(flop) / 1.e9 / time;
std::cout << "[" << problem.data_type << "|";
if(problem.input_layout == problem.output_layout)
{
std::cout << problem.input_layout;
}
else
{
std::cout << problem.input_layout << "-" << problem.output_layout;
}
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
<< ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
<< ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed
@@ -469,85 +462,70 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
}
// transpose tensor descriptors from bhsd to bshd if necessary
if(problem.input_layout != TensorLayout::bshd)
{
q = q.transpose({0, 2, 1, 3});
k = k.transpose({0, 2, 1, 3});
v = v.transpose({0, 2, 1, 3});
}
// if(problem.input_layout != TensorLayout::bshd)
// {
// q = q.transpose({0, 2, 1, 3});
// k = k.transpose({0, 2, 1, 3});
// v = v.transpose({0, 2, 1, 3});
// }
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
if(problem.output_layout != TensorLayout::bshd)
{
o_ref = o_ref.transpose({0, 2, 1, 3});
}
// ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
// if(problem.output_layout != TensorLayout::bshd)
// {
// o_ref = o_ref.transpose({0, 2, 1, 3});
// }
// If variable lengths are provided, compute per-batch references
// with the effective lengths; else compute a single full reference.
if(has_varlen_q || has_varlen_k)
// Variable-length aware verification: zero-fill padded region and only compute valid part.
o_ref.SetZero();
for(int b = 0; b < problem.batch; ++b)
{
// Variable-length aware verification: zero-fill padded region and only compute valid part.
o_ref.SetZero();
const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
for(int b = 0; b < problem.batch; ++b)
{
const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
continue;
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
continue;
// Slice current batch from inputs (bshd) and build single-batch tensors
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
// Slice current batch from inputs (bshd) and build single-batch tensors
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
// Copy effective region
q_b.ForEach([&](auto& self, auto idx) {
// idx: [0, s, h, d]
self(idx) = q(b, idx[1], idx[2], idx[3]);
});
k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
// Copy effective region
q_b.ForEach([&](auto& self, auto idx) {
// idx: [0, s, h, d]
self(idx) = q(b, idx[1], idx[2], idx[3]);
});
k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
host::fmha_fwd<float, DataType>(q_b,
k_b,
v_b,
problem.mask,
o_b,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
// Scatter into o_ref's bshd descriptor memory
for(int s = 0; s < seqlen_q_eff; ++s)
{
for(int h = 0; h < problem.nhead_q; ++h)
{
for(int d = 0; d < problem.hdim; ++d)
{
o_ref(b, s, h, d) = o_b(0, s, h, d);
}
}
}
}
}
else
{
// No varlen override: compute the full reference once
host::fmha_fwd<float, DataType>(q,
k,
v,
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
host::fmha_fwd<float, DataType>(q_b,
k_b,
v_b,
problem.mask,
o_ref,
o_b,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
// Scatter into o_ref's bshd descriptor memory
for(int s = 0; s < seqlen_q_eff; ++s)
{
for(int h = 0; h < problem.nhead_q; ++h)
{
for(int d = 0; d < problem.hdim; ++d)
{
o_ref(b, s, h, d) = o_b(0, s, h, d);
}
}
}
}
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false, false>;
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

View File

@@ -7,7 +7,7 @@
namespace ck_tile {
using kernel_traits =
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false, true>;
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true>;
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)

View File

@@ -20,7 +20,7 @@
#include "unified_attention.hpp"
#include "mask.hpp"
#define INST_unified_attention_DISPATCH(kernel_traits) \
#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \
template <> \
std::pair<bool, float> unified_attention_kernel_dispatch<kernel_traits>( \
const unified_attention_args& args, const stream_config& config) \
@@ -73,7 +73,6 @@ struct unified_attention_kernel_traits
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
false, // kPadHeadDimQ
false, // kStoreLSE_
-1 // kBlockPerCu
>;
@@ -110,6 +109,7 @@ struct unified_attention_kernel_traits
template <typename Kernel>
float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config)
{
index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs;
auto kargs = Kernel::MakeKargs(args.q_ptr,
args.k_ptr,
@@ -123,6 +123,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const
args.scale_k,
args.scale_v,
args.scale_out,
total_num_q_blocks,
args.query_stride_0,
args.query_stride_1,
args.stride_k_cache_0,
@@ -141,9 +142,6 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const
args.num_seqs
);
index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs;
dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;

View File

@@ -156,12 +156,11 @@ struct UnifiedAttentionKernel
stride_v_cache_3,
output_stride_0,
output_stride_1},
{
block_tables_ptr,
seq_lens_ptr,
query_start_len_ptr,
num_seqs
}};
};
return kargs;
}
@@ -344,7 +343,7 @@ struct UnifiedAttentionKernel
index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q;
const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0);
// const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0);
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
@@ -359,7 +358,7 @@ struct UnifiedAttentionKernel
q_dram_base,
// block sizes
make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
sequence<is_query_len_padded, false, kPadHeadDimQ>{}
sequence<true, false, kPadHeadDimQ>{}
); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED)
const auto q_dram_merged = transform_tensor_view(
@@ -486,7 +485,7 @@ struct UnifiedAttentionKernel
o_dram_base,
// block sizes
make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
sequence<is_query_len_padded, false, kPadHeadDimQ>{}
sequence<true, false, kPadHeadDimQ>{}
); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED)
const auto o_dram_merged = transform_tensor_view(

View File

@@ -12,13 +12,11 @@ namespace ck_tile {
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDim_ /* padding for hdim_v */,
bool kStoreLSE_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileUnifiedAttentionTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDim = kPadHeadDim_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
}

View File

@@ -256,7 +256,6 @@ struct UnifiedAttentionPipeline
using VDataType = ck_tile::remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = ck_tile::remove_cvref_t<typename Problem::SMPLComputeDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = ck_tile::remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>;
@@ -372,11 +371,9 @@ struct UnifiedAttentionPipeline
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename LSEElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction>
@@ -1206,14 +1203,12 @@ struct UnifiedAttentionPipeline
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename LSEDramBlockWindowTmp>
typename VDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const void* block_tables_ptr,
index_t block_table_offset,
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
FmhaMask mask,
float scale_s,
void* smem_ptr) const
@@ -1228,7 +1223,6 @@ struct UnifiedAttentionPipeline
identity{},
block_tables_ptr,
block_table_offset,
lse_dram_block_window_tmp,
identity{},
identity{},
identity{},

View File

@@ -51,8 +51,6 @@ struct UnifiedAttentionPipelineProblem
static constexpr bool kPadHeadDim = Traits::kPadHeadDim;
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;