mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
More compilation fixes
This commit is contained in:
@@ -178,6 +178,15 @@
|
||||
|
||||
# --- Unified Attention target (kept) ---
|
||||
|
||||
#
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
# Currently only gfx9 archs are supported by FMHA
|
||||
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
|
||||
if(NOT INST_TARGETS)
|
||||
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
set(EXAMPLE_FMHA_FWD_V3 "tile_example_unified_attention")
|
||||
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
|
||||
|
||||
|
||||
@@ -154,7 +154,6 @@ struct Problem
|
||||
float scale_k;
|
||||
float scale_v;
|
||||
mask_info mask;
|
||||
TensorLayout output_layout;
|
||||
std::vector<int> query_lens;
|
||||
std::vector<int> kv_lens;
|
||||
};
|
||||
@@ -350,8 +349,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
args.stride_v_cache_3 = args.stride_k_cache_3;
|
||||
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
args.output_stride_0 = query_stride_0;
|
||||
args.output_stride_1 = query_stride_1;
|
||||
args.output_stride_0 = args.query_stride_0;
|
||||
args.output_stride_1 = args.query_stride_1;
|
||||
|
||||
// Optional cumulative seqlen overrides (exclude PAD)
|
||||
auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
|
||||
@@ -386,19 +385,19 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
};
|
||||
calculate_cumulative(eff_query_lens, cu_query_lens);
|
||||
|
||||
ck_tile::DeviceMem seq_lens_buf(kv_lens.size());
|
||||
ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size());
|
||||
ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size());
|
||||
|
||||
seq_lens_buf.ToDevice(kv_lens.data());
|
||||
seq_lens_buf.ToDevice(eff_kv_lens.data());
|
||||
query_start_len_buf.ToDevice(cu_query_lens.data());
|
||||
|
||||
args.seq_lens_ptr =reinterpret_cast<const ck_tile::index_t*>(seq_lens_buf.GetDeviceBuffer());
|
||||
args.query_start_len_ptr =reinterpret_cast<const ck_tile::index_t*>(query_start_len_buf.GetDeviceBuffer());
|
||||
|
||||
|
||||
auto max_kv_len = std::max_element(problem.kv_lens.begin(), problem.kv_lens.end());
|
||||
int max_kv_len = std::max_element(eff_kv_lens.begin(), eff_kv_lens.end());
|
||||
|
||||
index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE
|
||||
ck_tile::index_t max_num_blocks_per_seq = (max_kv_len + problem.BLOCK_SIZE - 1) / problem.BLOCK_SIZE;
|
||||
|
||||
// Create block_tables
|
||||
ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq * sizeof(ck_tile::index_t));
|
||||
@@ -433,30 +432,24 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
return false;
|
||||
}
|
||||
|
||||
std::size_t flop = [&] {
|
||||
if(problem.mask.type == mask_enum::no_mask)
|
||||
{
|
||||
return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
|
||||
problem.hdim;
|
||||
}
|
||||
else
|
||||
{
|
||||
/// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2.
|
||||
return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
|
||||
problem.hdim;
|
||||
}
|
||||
}();
|
||||
// std::size_t flop = [&] {
|
||||
// if(problem.mask.type == mask_enum::no_mask)
|
||||
// {
|
||||
// return 4 * args.num_tokens * problem.nhead_q *
|
||||
// problem.hdim;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// /// FIXME: Use a more accurate method; for now, we’re just dividing the flop by 2.
|
||||
// return 2 * args.num_tokens * problem.nhead_q *
|
||||
// problem.hdim;
|
||||
// }
|
||||
// }();
|
||||
// TODO fix this
|
||||
std::size_t flop = 1;
|
||||
float tflops = static_cast<float>(flop) / 1.e9 / time;
|
||||
|
||||
std::cout << "[" << problem.data_type << "|";
|
||||
if(problem.input_layout == problem.output_layout)
|
||||
{
|
||||
std::cout << problem.input_layout;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << problem.input_layout << "-" << problem.output_layout;
|
||||
}
|
||||
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
|
||||
<< ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
|
||||
<< ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed
|
||||
@@ -469,85 +462,70 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
}
|
||||
|
||||
// transpose tensor descriptors from bhsd to bshd if necessary
|
||||
if(problem.input_layout != TensorLayout::bshd)
|
||||
{
|
||||
q = q.transpose({0, 2, 1, 3});
|
||||
k = k.transpose({0, 2, 1, 3});
|
||||
v = v.transpose({0, 2, 1, 3});
|
||||
}
|
||||
// if(problem.input_layout != TensorLayout::bshd)
|
||||
// {
|
||||
// q = q.transpose({0, 2, 1, 3});
|
||||
// k = k.transpose({0, 2, 1, 3});
|
||||
// v = v.transpose({0, 2, 1, 3});
|
||||
// }
|
||||
|
||||
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
|
||||
if(problem.output_layout != TensorLayout::bshd)
|
||||
{
|
||||
o_ref = o_ref.transpose({0, 2, 1, 3});
|
||||
}
|
||||
// ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
|
||||
// if(problem.output_layout != TensorLayout::bshd)
|
||||
// {
|
||||
// o_ref = o_ref.transpose({0, 2, 1, 3});
|
||||
// }
|
||||
|
||||
// If variable lengths are provided, compute per-batch references
|
||||
// with the effective lengths; else compute a single full reference.
|
||||
if(has_varlen_q || has_varlen_k)
|
||||
// Variable-length aware verification: zero-fill padded region and only compute valid part.
|
||||
o_ref.SetZero();
|
||||
|
||||
for(int b = 0; b < problem.batch; ++b)
|
||||
{
|
||||
// Variable-length aware verification: zero-fill padded region and only compute valid part.
|
||||
o_ref.SetZero();
|
||||
const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
|
||||
const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
|
||||
|
||||
for(int b = 0; b < problem.batch; ++b)
|
||||
{
|
||||
const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
|
||||
const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
|
||||
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
|
||||
continue;
|
||||
|
||||
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
|
||||
continue;
|
||||
// Slice current batch from inputs (bshd) and build single-batch tensors
|
||||
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
|
||||
// Slice current batch from inputs (bshd) and build single-batch tensors
|
||||
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
// Copy effective region
|
||||
q_b.ForEach([&](auto& self, auto idx) {
|
||||
// idx: [0, s, h, d]
|
||||
self(idx) = q(b, idx[1], idx[2], idx[3]);
|
||||
});
|
||||
k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
|
||||
v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
|
||||
|
||||
// Copy effective region
|
||||
q_b.ForEach([&](auto& self, auto idx) {
|
||||
// idx: [0, s, h, d]
|
||||
self(idx) = q(b, idx[1], idx[2], idx[3]);
|
||||
});
|
||||
k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
|
||||
v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
|
||||
|
||||
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
|
||||
host::fmha_fwd<float, DataType>(q_b,
|
||||
k_b,
|
||||
v_b,
|
||||
problem.mask,
|
||||
o_b,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales{problem.softmax_scale});
|
||||
|
||||
// Scatter into o_ref's bshd descriptor memory
|
||||
for(int s = 0; s < seqlen_q_eff; ++s)
|
||||
{
|
||||
for(int h = 0; h < problem.nhead_q; ++h)
|
||||
{
|
||||
for(int d = 0; d < problem.hdim; ++d)
|
||||
{
|
||||
o_ref(b, s, h, d) = o_b(0, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// No varlen override: compute the full reference once
|
||||
host::fmha_fwd<float, DataType>(q,
|
||||
k,
|
||||
v,
|
||||
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
|
||||
host::fmha_fwd<float, DataType>(q_b,
|
||||
k_b,
|
||||
v_b,
|
||||
problem.mask,
|
||||
o_ref,
|
||||
o_b,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales{problem.softmax_scale});
|
||||
|
||||
// Scatter into o_ref's bshd descriptor memory
|
||||
for(int s = 0; s < seqlen_q_eff; ++s)
|
||||
{
|
||||
for(int h = 0; h < problem.nhead_q; ++h)
|
||||
{
|
||||
for(int d = 0; d < problem.hdim; ++d)
|
||||
{
|
||||
o_ref(b, s, h, d) = o_b(0, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
|
||||
o_buf.FromDevice(o.data());
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false, false>;
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false, true>;
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
#include "unified_attention.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
#define INST_unified_attention_DISPATCH(kernel_traits) \
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch<kernel_traits>( \
|
||||
const unified_attention_args& args, const stream_config& config) \
|
||||
@@ -73,7 +73,6 @@ struct unified_attention_kernel_traits
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
|
||||
false, // kPadHeadDimQ
|
||||
false, // kStoreLSE_
|
||||
-1 // kBlockPerCu
|
||||
>;
|
||||
|
||||
@@ -110,6 +109,7 @@ struct unified_attention_kernel_traits
|
||||
template <typename Kernel>
|
||||
float unified_attention_kernel_launch(const unified_attention_args& args, const stream_config& config)
|
||||
{
|
||||
index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
@@ -123,6 +123,7 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const
|
||||
args.scale_k,
|
||||
args.scale_v,
|
||||
args.scale_out,
|
||||
total_num_q_blocks,
|
||||
args.query_stride_0,
|
||||
args.query_stride_1,
|
||||
args.stride_k_cache_0,
|
||||
@@ -141,9 +142,6 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const
|
||||
args.num_seqs
|
||||
);
|
||||
|
||||
index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs;
|
||||
|
||||
|
||||
dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
|
||||
|
||||
@@ -156,12 +156,11 @@ struct UnifiedAttentionKernel
|
||||
stride_v_cache_3,
|
||||
output_stride_0,
|
||||
output_stride_1},
|
||||
{
|
||||
block_tables_ptr,
|
||||
seq_lens_ptr,
|
||||
query_start_len_ptr,
|
||||
num_seqs
|
||||
}};
|
||||
};
|
||||
|
||||
return kargs;
|
||||
}
|
||||
@@ -344,7 +343,7 @@ struct UnifiedAttentionKernel
|
||||
|
||||
|
||||
index_t query_len_padded = integer_divide_ceil(cur_batch_query_len, BLOCK_Q) * BLOCK_Q;
|
||||
const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0);
|
||||
// const bool is_query_len_padded = (cur_batch_query_len % BLOCK_Q == 0);
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
@@ -359,7 +358,7 @@ struct UnifiedAttentionKernel
|
||||
q_dram_base,
|
||||
// block sizes
|
||||
make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
|
||||
sequence<is_query_len_padded, false, kPadHeadDimQ>{}
|
||||
sequence<true, false, kPadHeadDimQ>{}
|
||||
); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED)
|
||||
|
||||
const auto q_dram_merged = transform_tensor_view(
|
||||
@@ -486,7 +485,7 @@ struct UnifiedAttentionKernel
|
||||
o_dram_base,
|
||||
// block sizes
|
||||
make_tuple(BLOCK_Q, 1, HEAD_SIZE_PADDED),
|
||||
sequence<is_query_len_padded, false, kPadHeadDimQ>{}
|
||||
sequence<true, false, kPadHeadDimQ>{}
|
||||
); // pads to (seq_len_padded, num_head_q, HEAD_SIZE_PADDED)
|
||||
|
||||
const auto o_dram_merged = transform_tensor_view(
|
||||
|
||||
@@ -12,13 +12,11 @@ namespace ck_tile {
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadHeadDim_ /* paddding for hdim_v */,
|
||||
bool kStoreLSE_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileUnifiedAttentionTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadHeadDim = kPadHeadDim_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -256,7 +256,6 @@ struct UnifiedAttentionPipeline
|
||||
using VDataType = ck_tile::remove_cvref_t<typename Problem::VDataType>;
|
||||
using SaccDataType = ck_tile::remove_cvref_t<typename Problem::SaccDataType>;
|
||||
using SMPLComputeDataType = ck_tile::remove_cvref_t<typename Problem::SMPLComputeDataType>;
|
||||
using LSEDataType = ck_tile::remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using PDataType = ck_tile::remove_cvref_t<typename Problem::PDataType>;
|
||||
using OaccDataType = ck_tile::remove_cvref_t<typename Problem::OaccDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename Problem::ODataType>;
|
||||
@@ -372,11 +371,9 @@ struct UnifiedAttentionPipeline
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction>
|
||||
@@ -1206,14 +1203,12 @@ struct UnifiedAttentionPipeline
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp>
|
||||
typename VDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
@@ -1228,7 +1223,6 @@ struct UnifiedAttentionPipeline
|
||||
identity{},
|
||||
block_tables_ptr,
|
||||
block_table_offset,
|
||||
lse_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
|
||||
@@ -51,8 +51,6 @@ struct UnifiedAttentionPipelineProblem
|
||||
static constexpr bool kPadHeadDim = Traits::kPadHeadDim;
|
||||
static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
|
||||
static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Traits::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
|
||||
Reference in New Issue
Block a user