mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Log UA argument in the instance heuristic selection
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
#include "mask.hpp"
|
||||
#include <cstdint>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -18,10 +19,75 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
void write_ptr(std::ostream& s, const void* p)
|
||||
{
|
||||
if(p == nullptr)
|
||||
{
|
||||
s << "nullptr";
|
||||
return;
|
||||
}
|
||||
const std::ios::fmtflags f = s.flags();
|
||||
s << "0x" << std::hex << reinterpret_cast<std::uintptr_t>(p) << std::dec;
|
||||
s.flags(f);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const unified_attention_args& args)
|
||||
{
|
||||
// Single line, comma-separated key=value (keys match struct fields for easy parsing).
|
||||
stream << "unified_attention_args, "
|
||||
<< " data_type=" << args.data_type
|
||||
<< ", mask_type=" << args.mask_type
|
||||
<< ", num_tokens=" << args.num_tokens
|
||||
<< ", num_blks=" << args.num_blks
|
||||
<< ", num_head_q=" << args.num_head_q
|
||||
<< ", num_queries_per_kv=" << args.num_queries_per_kv
|
||||
<< ", page_blk_size=" << args.page_blk_size
|
||||
<< ", hdim=" << args.hdim
|
||||
<< ", scale_s=" << args.scale_s
|
||||
<< ", scale=" << args.scale
|
||||
<< ", scale_k=" << args.scale_k
|
||||
<< ", scale_v=" << args.scale_v
|
||||
<< ", scale_out=" << args.scale_out
|
||||
// << ", q_ptr=";
|
||||
// write_ptr(stream, args.q_ptr);
|
||||
<< ", query_stride_0=" << args.query_stride_0
|
||||
<< ", query_stride_1=" << args.query_stride_1
|
||||
// << ", k_ptr=";
|
||||
// write_ptr(stream, args.k_ptr);
|
||||
<< ", stride_k_cache_0=" << args.stride_k_cache_0
|
||||
<< ", stride_k_cache_1=" << args.stride_k_cache_1
|
||||
<< ", stride_k_cache_2=" << args.stride_k_cache_2
|
||||
<< ", stride_k_cache_3=" << args.stride_k_cache_3
|
||||
// << ", v_ptr=";
|
||||
// write_ptr(stream, args.v_ptr);
|
||||
<< ", stride_v_cache_0=" << args.stride_v_cache_0
|
||||
<< ", stride_v_cache_1=" << args.stride_v_cache_1
|
||||
<< ", stride_v_cache_2=" << args.stride_v_cache_2
|
||||
<< ", stride_v_cache_3=" << args.stride_v_cache_3
|
||||
// << ", o_ptr=";
|
||||
// write_ptr(stream, args.o_ptr);
|
||||
<< ", output_stride_0=" << args.output_stride_0
|
||||
<< ", output_stride_1=" << args.output_stride_1
|
||||
// << ", block_tables_ptr=";
|
||||
// write_ptr(stream, static_cast<const void*>(args.block_tables_ptr));
|
||||
<< ", block_table_stride=" << args.block_table_stride;
|
||||
// << ", seq_lens_ptr=";
|
||||
// write_ptr(stream, static_cast<const void*>(args.seq_lens_ptr));
|
||||
// stream << ", query_start_len_ptr=";
|
||||
// write_ptr(stream, static_cast<const void*>(args.query_start_len_ptr));
|
||||
return stream << ", num_seqs=" << args.num_seqs
|
||||
<< ", max_seqlen_q=" << args.max_seqlen_q << " }";
|
||||
}
|
||||
|
||||
// Helper macro to reduce dispatch boilerplate.
|
||||
// Dispatches based on DataType, IsMasking, HeadSize, BlockM, NumQPerKV.
|
||||
#define DISPATCH_UNIFIED_ATTENTION(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
std::cout << "DISPATCH_UNIFIED_ATTENTION: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
|
||||
using kernel_traits = unified_attention_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
@@ -29,18 +95,21 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
// Dispatch macros for three tile tiers (default block_size).
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
|
||||
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
|
||||
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_TINY: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
|
||||
using kernel_traits = unified_attention_decode_tiny_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
@@ -48,18 +117,21 @@ std::ostream& operator<<(std::ostream& stream,
|
||||
// block_size=32 dispatch macros (6th template arg = 32).
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
|
||||
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
|
||||
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
std::cout << "DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW: DType=" << DType << " IsMask=" << IsMask << " HSize=" << HSize << " BM=" << BM << " NQPKV=" << NQPKV << std::endl; \
|
||||
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
@@ -72,6 +72,8 @@ struct unified_attention_args
|
||||
std::ostream& operator<<(std::ostream& stream,
|
||||
const unified_attention_args::data_type_enum& data_type);
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, const unified_attention_args& args);
|
||||
|
||||
// return value:
|
||||
// first = whether the kernel was launched (true = launched, false = skipped)
|
||||
// second = elapsed time (ms) of the kernel launch, valid only if first == true
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include <iostream>
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
@@ -33,9 +34,10 @@
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch_decode<kernel_traits>( \
|
||||
const unified_attention_args& args, const stream_config& config) \
|
||||
{ \
|
||||
return std::make_pair( \
|
||||
true, unified_attention_kernel_launch<kernel_traits::kernel, true>(args, config)); \
|
||||
{ \
|
||||
std::cout << "INST_UNIFIED_ATTENTION_DISPATCH_DECODE, " << args << std::endl; \
|
||||
return std::make_pair( \
|
||||
true, unified_attention_kernel_launch<kernel_traits::kernel, true>(args, config)); \
|
||||
}
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
Reference in New Issue
Block a user