Log UA argument in the instance heuristic selection

This commit is contained in:
Damien Lejeune
2026-04-23 13:07:42 +00:00
parent 3f076a6fc1
commit ce751cf74d
3 changed files with 79 additions and 3 deletions

View File

@@ -4,6 +4,7 @@
#include "unified_attention.hpp"
#include "unified_attention_impl.hpp"
#include "mask.hpp"
#include <cstdint>
namespace ck_tile {
@@ -18,10 +19,75 @@ std::ostream& operator<<(std::ostream& stream,
}
}
namespace {
/// Writes a pointer to the stream as "0x<hex-digits>", or the literal
/// "nullptr" when the pointer is null, restoring the stream's original
/// formatting flags afterwards so the std::hex below does not leak into
/// later output on the same stream.
///
/// Currently only referenced from commented-out logging lines in
/// operator<<(std::ostream&, const unified_attention_args&); marked
/// [[maybe_unused]] so -Wunused-function stays quiet until those call
/// sites are re-enabled.
[[maybe_unused]] void write_ptr(std::ostream& s, const void* p)
{
    if(p == nullptr)
    {
        s << "nullptr";
        return;
    }
    // Save flags, switch to hex for the address, then restore. The trailing
    // std::dec is redundant with flags(f) but kept as belt-and-braces.
    const std::ios::fmtflags f = s.flags();
    s << "0x" << std::hex << reinterpret_cast<std::uintptr_t>(p) << std::dec;
    s.flags(f);
}
} // namespace
std::ostream& operator<<(std::ostream& stream, const unified_attention_args& args)
{
    // Single line, comma-separated key=value (keys match struct fields for easy parsing).
    // Fix: emit an opening "{" so the record is balanced with the closing " }"
    // at the end (the previous version opened with a bare "unified_attention_args, ",
    // leaving an unmatched brace in every log line).
    // Pointer fields (q_ptr, k_ptr, v_ptr, o_ptr, block_tables_ptr, ...) are
    // deliberately not logged — addresses vary run to run and would defeat log
    // diffing; uncomment the write_ptr() calls below to restore them for a
    // debugging session.
    stream << "unified_attention_args {"
           << " data_type=" << args.data_type
           << ", mask_type=" << args.mask_type
           << ", num_tokens=" << args.num_tokens
           << ", num_blks=" << args.num_blks
           << ", num_head_q=" << args.num_head_q
           << ", num_queries_per_kv=" << args.num_queries_per_kv
           << ", page_blk_size=" << args.page_blk_size
           << ", hdim=" << args.hdim
           << ", scale_s=" << args.scale_s
           << ", scale=" << args.scale
           << ", scale_k=" << args.scale_k
           << ", scale_v=" << args.scale_v
           << ", scale_out=" << args.scale_out
           // << ", q_ptr=";
           // write_ptr(stream, args.q_ptr);
           << ", query_stride_0=" << args.query_stride_0
           << ", query_stride_1=" << args.query_stride_1
           // << ", k_ptr=";
           // write_ptr(stream, args.k_ptr);
           << ", stride_k_cache_0=" << args.stride_k_cache_0
           << ", stride_k_cache_1=" << args.stride_k_cache_1
           << ", stride_k_cache_2=" << args.stride_k_cache_2
           << ", stride_k_cache_3=" << args.stride_k_cache_3
           // << ", v_ptr=";
           // write_ptr(stream, args.v_ptr);
           << ", stride_v_cache_0=" << args.stride_v_cache_0
           << ", stride_v_cache_1=" << args.stride_v_cache_1
           << ", stride_v_cache_2=" << args.stride_v_cache_2
           << ", stride_v_cache_3=" << args.stride_v_cache_3
           // << ", o_ptr=";
           // write_ptr(stream, args.o_ptr);
           << ", output_stride_0=" << args.output_stride_0
           << ", output_stride_1=" << args.output_stride_1
           // << ", block_tables_ptr=";
           // write_ptr(stream, static_cast<const void*>(args.block_tables_ptr));
           << ", block_table_stride=" << args.block_table_stride;
    // << ", seq_lens_ptr=";
    // write_ptr(stream, static_cast<const void*>(args.seq_lens_ptr));
    // stream << ", query_start_len_ptr=";
    // write_ptr(stream, static_cast<const void*>(args.query_start_len_ptr));
    return stream << ", num_seqs=" << args.num_seqs
                  << ", max_seqlen_q=" << args.max_seqlen_q << " }";
}
// Helper macro to reduce dispatch boilerplate.
// Dispatches based on DataType, IsMasking, HeadSize, BlockM, NumQPerKV.
// NOTE: the expansion is a brace block ending in a `return`, so it must be
// invoked from a function whose return type matches
// unified_attention_kernel_dispatch's, with `args` and `config` in scope.
// The std::cout line traces which instantiation was selected; std::endl
// flushes each trace line.
#define DISPATCH_UNIFIED_ATTENTION(DType, IsMask, HSize, BM, NQPKV) \
{ \
std::cout << "DISPATCH_UNIFIED_ATTENTION: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
using kernel_traits = unified_attention_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
@@ -29,18 +95,21 @@ std::ostream& operator<<(std::ostream& stream,
// Dispatch macros for three tile tiers (default block_size).
// Medium tier: instantiates unified_attention_decode_kernel_traits.
// Expands to a brace block ending in `return`; requires `args`/`config` in scope.
// NOTE(review): this tier routes through unified_attention_kernel_dispatch,
// whereas the SMALL/TINY tiers use unified_attention_kernel_dispatch_decode —
// confirm the asymmetry is intentional.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType, IsMask, HSize, BM, NQPKV) \
{ \
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
// Small tier: instantiates unified_attention_decode_small_kernel_traits and
// routes through the decode dispatch path. Expands to a brace block ending in
// `return`; requires `args`/`config` in scope at the call site.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType, IsMask, HSize, BM, NQPKV) \
{ \
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
// Tiny tier: instantiates unified_attention_decode_tiny_kernel_traits and
// routes through the decode dispatch path. Expands to a brace block ending in
// `return`; requires `args`/`config` in scope at the call site.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType, IsMask, HSize, BM, NQPKV) \
{ \
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_TINY: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
using kernel_traits = unified_attention_decode_tiny_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
@@ -48,18 +117,21 @@ std::ostream& operator<<(std::ostream& stream,
// block_size=32 dispatch macros (6th template arg = 32).
// Medium tier, block_size=32 variant of unified_attention_decode_kernel_traits.
// Expands to a brace block ending in `return`; requires `args`/`config` in scope.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType, IsMask, HSize, BM, NQPKV) \
{ \
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
}
// Small tier, block_size=32 variant (6th template arg = 32); routes through
// the decode dispatch path. Expands to a brace block ending in `return`;
// requires `args`/`config` in scope at the call site.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType, IsMask, HSize, BM, NQPKV) \
{ \
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}
// Narrow block_size=32 variant: uses the dedicated
// unified_attention_decode_bs32_kernel_traits and the decode dispatch path.
// Expands to a brace block ending in `return`; requires `args`/`config` in
// scope at the call site.
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType, IsMask, HSize, BM, NQPKV) \
{ \
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
}

View File

@@ -72,6 +72,8 @@ struct unified_attention_args
std::ostream& operator<<(std::ostream& stream,
const unified_attention_args::data_type_enum& data_type);
std::ostream& operator<<(std::ostream& stream, const unified_attention_args& args);
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true

View File

@@ -5,6 +5,7 @@
#include <utility>
#include <iostream>
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/container/sequence.hpp"
@@ -33,9 +34,10 @@
template <> \
std::pair<bool, float> unified_attention_kernel_dispatch_decode<kernel_traits>( \
const unified_attention_args& args, const stream_config& config) \
{ \
return std::make_pair( \
true, unified_attention_kernel_launch<kernel_traits::kernel, true>(args, config)); \
{ \
std::cout << "INST_UNIFIED_ATTENTION_DISPATCH_DECODE, " << args << std::endl; \
return std::make_pair( \
true, unified_attention_kernel_launch<kernel_traits::kernel, true>(args, config)); \
}
namespace ck_tile {