From ce751cf74df74405be41ebc5c1688561909d70f1 Mon Sep 17 00:00:00 2001 From: Damien Lejeune Date: Thu, 23 Apr 2026 13:07:42 +0000 Subject: [PATCH] Log UA argument in the instance heuristic selection --- .../unified_attention.cpp | 72 +++++++++++++++++++ .../unified_attention.hpp | 2 + .../unified_attention_impl.hpp | 8 ++- 3 files changed, 79 insertions(+), 3 deletions(-) diff --git a/example/ck_tile/42_unified_attention/unified_attention.cpp b/example/ck_tile/42_unified_attention/unified_attention.cpp index bdeb56aed9..7bb33f8dd0 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.cpp +++ b/example/ck_tile/42_unified_attention/unified_attention.cpp @@ -4,6 +4,7 @@ #include "unified_attention.hpp" #include "unified_attention_impl.hpp" #include "mask.hpp" +#include namespace ck_tile { @@ -18,10 +19,75 @@ std::ostream& operator<<(std::ostream& stream, } } +namespace { +/* Writes p to s as 0x-prefixed hex, or the text "nullptr" when p is null, and leaves the stream's format flags as it found them. Marked [[maybe_unused]] because every call added alongside it is commented out, which would otherwise raise -Wunused-function. */ +[[maybe_unused]] void write_ptr(std::ostream& s, const void* p) +{ + if(p == nullptr) + { + s << "nullptr"; + return; + } + const std::ios::fmtflags f = s.flags(); + s << "0x" << std::hex << reinterpret_cast(p) << std::dec; + s.flags(f); +} + +} // namespace + +std::ostream& operator<<(std::ostream& stream, const unified_attention_args& args) +{ + // Single line, comma-separated key=value (keys match struct fields for easy parsing). 
+ /* Emit the struct label, then one ", key=value" pair per field. No brace or padding is written around the pairs, and the pointer members are deliberately left out (their lines stay commented below). */ stream << "unified_attention_args" + << ", data_type=" << args.data_type + << ", mask_type=" << args.mask_type + << ", num_tokens=" << args.num_tokens + << ", num_blks=" << args.num_blks + << ", num_head_q=" << args.num_head_q + << ", num_queries_per_kv=" << args.num_queries_per_kv + << ", page_blk_size=" << args.page_blk_size + << ", hdim=" << args.hdim + << ", scale_s=" << args.scale_s + << ", scale=" << args.scale + << ", scale_k=" << args.scale_k + << ", scale_v=" << args.scale_v + << ", scale_out=" << args.scale_out + // << ", q_ptr="; + // write_ptr(stream, args.q_ptr); + << ", query_stride_0=" << args.query_stride_0 + << ", query_stride_1=" << args.query_stride_1 + // << ", k_ptr="; + // write_ptr(stream, args.k_ptr); + << ", stride_k_cache_0=" << args.stride_k_cache_0 + << ", stride_k_cache_1=" << args.stride_k_cache_1 + << ", stride_k_cache_2=" << args.stride_k_cache_2 + << ", stride_k_cache_3=" << args.stride_k_cache_3 + // << ", v_ptr="; + // write_ptr(stream, args.v_ptr); + << ", stride_v_cache_0=" << args.stride_v_cache_0 + << ", stride_v_cache_1=" << args.stride_v_cache_1 + << ", stride_v_cache_2=" << args.stride_v_cache_2 + << ", stride_v_cache_3=" << args.stride_v_cache_3 + // << ", o_ptr="; + // write_ptr(stream, args.o_ptr); + << ", output_stride_0=" << args.output_stride_0 + << ", output_stride_1=" << args.output_stride_1 + // << ", block_tables_ptr="; + // write_ptr(stream, static_cast(args.block_tables_ptr)); + << ", block_table_stride=" << args.block_table_stride; + // << ", seq_lens_ptr="; + // write_ptr(stream, static_cast(args.seq_lens_ptr)); + // stream << ", query_start_len_ptr="; + // write_ptr(stream, static_cast(args.query_start_len_ptr)); + return stream << ", num_seqs=" << args.num_seqs + << ", max_seqlen_q=" << args.max_seqlen_q; +} + // Helper macro to reduce dispatch boilerplate. // Dispatches based on DataType, IsMasking, HeadSize, BlockM, NumQPerKV. 
#define DISPATCH_UNIFIED_ATTENTION(DType, IsMask, HSize, BM, NQPKV) \ { \ + /* Trace: print this macro's name and its five arguments to stdout; std::endl also flushes. Arguments are parenthesized so each is streamed as one operand. */ std::cout << "DISPATCH_UNIFIED_ATTENTION: DType=" << (DType) << " IsMask=" << (IsMask) << " HSize=" << (HSize) << " BM=" << (BM) << " NQPKV=" << (NQPKV) << std::endl; \ using kernel_traits = unified_attention_kernel_traits; \ return unified_attention_kernel_dispatch(args, config); \ } @@ -29,18 +95,21 @@ std::ostream& operator<<(std::ostream& stream, // Dispatch macros for three tile tiers (default block_size). #define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType, IsMask, HSize, BM, NQPKV) \ { \ + /* Trace: print this macro's name and its five arguments to stdout; std::endl also flushes. */ std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM: DType=" << (DType) << " IsMask=" << (IsMask) << " HSize=" << (HSize) << " BM=" << (BM) << " NQPKV=" << (NQPKV) << std::endl; \ using kernel_traits = unified_attention_decode_kernel_traits; \ return unified_attention_kernel_dispatch(args, config); \ } #define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType, IsMask, HSize, BM, NQPKV) \ { \ + /* Trace: print this macro's name and its five arguments to stdout; std::endl also flushes. */ std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL: DType=" << (DType) << " IsMask=" << (IsMask) << " HSize=" << (HSize) << " BM=" << (BM) << " NQPKV=" << (NQPKV) << std::endl; \ using kernel_traits = unified_attention_decode_small_kernel_traits; \ return unified_attention_kernel_dispatch_decode(args, config); \ } #define DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType, IsMask, HSize, BM, NQPKV) \ { \ + /* Trace: print this macro's name and its five arguments to stdout; std::endl also flushes. */ std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_TINY: DType=" << (DType) << " IsMask=" << (IsMask) << " HSize=" << (HSize) << " BM=" << (BM) << " NQPKV=" << (NQPKV) << std::endl; \ using kernel_traits = unified_attention_decode_tiny_kernel_traits; \ return unified_attention_kernel_dispatch_decode(args, config); \ } @@ -48,18 +117,21 @@ std::ostream& operator<<(std::ostream& stream, // block_size=32 dispatch macros (6th template arg = 32). 
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType, IsMask, HSize, BM, NQPKV) \ { \ + /* Trace: print this macro's name and its five arguments to stdout; std::endl also flushes. Arguments are parenthesized so each is streamed as one operand. */ std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32: DType=" << (DType) << " IsMask=" << (IsMask) << " HSize=" << (HSize) << " BM=" << (BM) << " NQPKV=" << (NQPKV) << std::endl; \ using kernel_traits = unified_attention_decode_kernel_traits; \ return unified_attention_kernel_dispatch(args, config); \ } #define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType, IsMask, HSize, BM, NQPKV) \ { \ + /* Trace: print this macro's name and its five arguments to stdout; std::endl also flushes. */ std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32: DType=" << (DType) << " IsMask=" << (IsMask) << " HSize=" << (HSize) << " BM=" << (BM) << " NQPKV=" << (NQPKV) << std::endl; \ using kernel_traits = unified_attention_decode_small_kernel_traits; \ return unified_attention_kernel_dispatch_decode(args, config); \ } #define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType, IsMask, HSize, BM, NQPKV) \ { \ + /* Trace: print this macro's name and its five arguments to stdout; std::endl also flushes. */ std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW: DType=" << (DType) << " IsMask=" << (IsMask) << " HSize=" << (HSize) << " BM=" << (BM) << " NQPKV=" << (NQPKV) << std::endl; \ using kernel_traits = unified_attention_decode_bs32_kernel_traits; \ return unified_attention_kernel_dispatch_decode(args, config); \ } diff --git a/example/ck_tile/42_unified_attention/unified_attention.hpp b/example/ck_tile/42_unified_attention/unified_attention.hpp index 8b645387a4..7bf52d09a6 100644 --- a/example/ck_tile/42_unified_attention/unified_attention.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention.hpp @@ -72,6 +72,8 @@ struct unified_attention_args std::ostream& operator<<(std::ostream& stream, const unified_attention_args::data_type_enum& data_type); +std::ostream& operator<<(std::ostream& stream, const unified_attention_args& args); + // return value: // first = whether the kernel was launched (true = launched, false = skipped) // second = elapsed time (ms) of the kernel launch, valid only if first == true diff --git 
a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp index 31e5c4c6ad..4f83be46b8 100644 --- a/example/ck_tile/42_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/42_unified_attention/unified_attention_impl.hpp @@ -5,6 +5,7 @@ #include +#include #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/container/sequence.hpp" @@ -33,9 +34,10 @@ template <> \ std::pair unified_attention_kernel_dispatch_decode( \ const unified_attention_args& args, const stream_config& config) \ - { \ - return std::make_pair( \ - true, unified_attention_kernel_launch(args, config)); \ + { \ + std::cout << "INST_UNIFIED_ATTENTION_DISPATCH_DECODE, " << args << std::endl; \ + return std::make_pair( \ + true, unified_attention_kernel_launch(args, config)); \ } namespace ck_tile {