This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
# reference
this folder contains reference implementation of a specific op. Note by including a specific header, you are including the implementation(expecially the gpu implementation) into your source code, and compile that kernel into the fatbin, hence may increase your kernel obj code length. Usually the header starts with `reference_` is a cpu reference implementation. The header starts with `naive_` contains a gpu implementation with a small launcher.
TODO: move `host/reference` under this folder

View File

@@ -0,0 +1,95 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include <array>
#include <vector>
namespace ck_tile {
// Helper function to convert std::vector to std::array for kernel parameters
template <ck_tile::index_t NDimSpatial>
inline std::array<ck_tile::long_index_t, NDimSpatial>
to_array(const std::vector<ck_tile::long_index_t>& vec)
{
std::array<ck_tile::long_index_t, NDimSpatial> arr;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
arr[i] = vec[i];
}
return arr;
}
// Helper to fill missing dimensions with default value
template <ck_tile::index_t NDimSpatial>
inline std::array<ck_tile::long_index_t, NDimSpatial>
to_array_with_default(const std::vector<ck_tile::long_index_t>& vec,
ck_tile::long_index_t default_val = 1)
{
std::array<ck_tile::long_index_t, NDimSpatial> arr;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
arr[i] = (static_cast<size_t>(i) < vec.size()) ? vec[i] : default_val;
}
return arr;
}
// Index calculation helpers for GPU reference kernels
namespace detail {
// Calculate linear input index for grouped convolution
// Layout: [N, spatial..., G, C]
template <index_t NDimSpatial>
inline __device__ long_index_t
calculate_input_index(index_t n,
index_t g,
index_t c,
const std::array<index_t, NDimSpatial>& spatial_idx,
const std::array<long_index_t, NDimSpatial + 3>& strides)
{
long_index_t idx = n * strides[0];
for(index_t i = 0; i < NDimSpatial; ++i)
idx += spatial_idx[i] * strides[i + 1];
idx += g * strides[NDimSpatial + 1] + c;
return idx;
}
// Calculate linear weight index for grouped convolution
// Layout: [G, K, spatial..., C]
template <index_t NDimSpatial>
inline __device__ long_index_t
calculate_weight_index(index_t g,
index_t k,
index_t c,
const std::array<index_t, NDimSpatial>& spatial_idx,
const std::array<long_index_t, NDimSpatial + 3>& strides)
{
long_index_t idx = g * strides[0] + k * strides[1];
for(index_t i = 0; i < NDimSpatial; ++i)
idx += spatial_idx[i] * strides[i + 2];
idx += c * strides[NDimSpatial + 2];
return idx;
}
// Calculate linear output index for grouped convolution
// Layout: [N, spatial..., G, K]
template <index_t NDimSpatial>
inline __device__ long_index_t
calculate_output_index(index_t n,
index_t g,
index_t k,
const std::array<index_t, NDimSpatial>& spatial_idx,
const std::array<long_index_t, NDimSpatial + 3>& strides)
{
long_index_t idx = n * strides[0];
for(index_t i = 0; i < NDimSpatial; ++i)
idx += spatial_idx[i] * strides[i + 1];
idx += g * strides[NDimSpatial + 1] + k;
return idx;
}
} // namespace detail
} // namespace ck_tile

View File

@@ -0,0 +1,826 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <thread>
#include <string>
namespace ck_tile {
enum class naive_attention_layout_enum
{
DEFAULT, // maybe this tensor is not used, set some irrelevant value
BSHD, // [batch, seqlen, nhead, hdim]
BHSD, // [batch, nhead, seqlen, hdim]
BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed
PHSD, // [pages, nhead, page_size, hdim]
// PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen
PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen
PHDS, // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen
// scale layout used for dynamic dequant
SCALE_HS, // [nhead, tokens] or [nhead, tokens-per-group], nhe KVCache quant
SCALE_SH, // [tokens, nhead]
};
// will used to specialize kernel variation
enum class naive_attention_variation_enum
{
FLASH_BATCHED = 0, // standard flash attention, or xformer/sdpa, used for training
FLASH_GROUPED,
DECODE_PAGED, // decode attn, where kv token from another buffer called kvcache
};
enum class naive_attention_quant_algo
{
NO = 0,
KV_8BIT_PERHEAD = 1,
// FP8/INT8 quant for KVCache, per-token quant
// [num_tokens, nhead, hdim] -> [nhead, num_tokens]
KV_8BIT_PERTOKEN = 2,
};
// TODO: for simplicity, this will be used as host/device arg
struct naive_attention_fwd_args
{
void* q_ptr;
void* k_ptr;
void* v_ptr;
void* o_ptr;
void* context_len_ptr; // [batch] used when seqlen kv come from a pointer(each element is a
// number, not cumsum)
void* page_table_ptr; // [batch, max_pages_per_seq] seqlen_kv is in different block(paged attn)
void* kscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
void* vscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
float scale_s;
int hdim;
int hdim_v; // could be cross-attn, where V and Q/K hdim are different
int batch_q;
int batch_kv;
int batch_ratio_kv; // batch_q / batch_kv
int seqlen_q; // in decode case, this should be 1
int seqlen_kv; // if context_len_ptr is not nullptr, ignore this field
int nhead_q;
int nhead_kv;
int nhead_ratio_kv; // nhead_q / nhead_kv
int page_size; // if paged, the seqlen-kv per each block
int max_pages_per_seq;
int max_kv_tokens; // used as stride to access kv scale ptr
};
// this is trait for host API
struct naive_attention_fwd_traits
{
std::string q_type;
std::string k_type;
std::string v_type;
std::string o_type;
std::string q_layout;
std::string k_layout;
std::string v_layout;
std::string o_layout;
int variation; // sync with naive_attention_variation_enum
int quant_algo; // sync with naive_attention_quant_algo
};
// this is trait for kernel template
template <naive_attention_variation_enum variation_, naive_attention_quant_algo quant_algo_>
struct naive_attention_fwd_kernel_traits
{
static constexpr naive_attention_variation_enum variation = variation_;
static constexpr naive_attention_quant_algo quant_algo = quant_algo_;
};
// for simplicity, please do not use const-reference type for the template type
template <typename QType,
typename KType,
typename VType,
typename OType,
typename AccType,
typename KVScaleType,
naive_attention_layout_enum QLayout,
naive_attention_layout_enum KLayout,
naive_attention_layout_enum VLayout,
naive_attention_layout_enum OLayout,
naive_attention_layout_enum KScaleLayout,
naive_attention_layout_enum VScaleLayout,
typename Traits>
struct naive_attention_fwd_kernel
{
static constexpr bool is_kvcache_i8 =
std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
static constexpr bool is_kvcache_fp8 =
std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
static constexpr int v_per_token_quant_group_size = 64;
static constexpr int kBlockSize = 256;
// TODO: hardcode
using SoftmaxType = float; // always using float to do softmax compute
using QuantComputeType = float; // used for quant/dequant scale compute
using QCompute = KType; // src A of gemm1, same type as K
using PType = VType; // src A of gemm2, same type as V
using OAccType = float; // always float, in case int8 FA
using p_vec_type = ext_vector_t<PType, 16 / sizeof(PType)>;
static constexpr int p_vec_elem = vector_traits<p_vec_type>::vector_size;
// clang-format off
template <typename T_> struct scale_max { static constexpr float value = 1; /* dummy code */ };
template <> struct scale_max<int8_t> { static constexpr float value = 127.0; };
template <> struct scale_max<fp8_t> { static constexpr float value = 240.0; };
// clang-format on
__host__ __device__ naive_attention_fwd_kernel() {}
template <typename T, naive_attention_layout_enum Layout>
struct addresser
{
int b, s, h, d; // batch, seqlen, nhead, hdim
T* base_ptr;
__device__ addresser(int b_, int s_, int h_, int d_, void* base_ptr_)
: b(b_), s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(base_ptr_))
{
}
// TODO: all the batch/nhead offset will accumulate to the base pointer
__device__ T* get_base(int i_b, int i_h)
{
if constexpr(Layout == naive_attention_layout_enum::BSHD)
return base_ptr + i_b * s * h * d + i_h * d;
else if constexpr(Layout == naive_attention_layout_enum::BHSD)
return base_ptr + i_b * s * h * d + i_h * s * d;
}
__device__ int get_offset(int i_s, int i_d)
{
if constexpr(Layout == naive_attention_layout_enum::BSHD)
return i_s * h * d + i_d;
else if constexpr(Layout == naive_attention_layout_enum::BHSD)
return i_s * d + i_d;
}
// below set of API will directly use pointer inside this struct
__device__ void init(int i_b, int i_h) { base_ptr = get_base(i_b, i_h); }
__device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
__device__ void store(T value, int i_s, int i_d) { base_ptr[get_offset(i_s, i_d)] = value; }
};
template <typename T, naive_attention_layout_enum Layout>
struct page_addresser
{
int s, h, d; // page_size, nhead, hdim
static constexpr int x = 16 / sizeof(T); // pack 4 dword
T* base_ptr;
int* page_table_ptr; // TODO: page table always int
int i_h; // store current head
__device__ page_addresser(int s_, int h_, int d_, void* base_ptr_, void* pptr_)
: s(s_),
h(h_),
d(d_),
base_ptr(reinterpret_cast<T*>(base_ptr_)),
page_table_ptr(reinterpret_cast<int*>(pptr_))
{
}
__device__ int64_t get_phy_page_idx(int i_s)
{
// dynamic compute page idx is simple but slow
int page_idx = i_s / s;
int phy = page_table_ptr[page_idx];
return static_cast<int64_t>(phy);
}
__device__ int get_phy_page_offset(int i_s)
{
// dynamic compute page idx is simple but slow
return i_s % s;
}
__device__ int64_t get_offset(int i_s, int i_d)
{
int page_offset = get_phy_page_offset(i_s);
int64_t page_idx = get_phy_page_idx(i_s);
int64_t base_ = page_idx * h * s * d;
if constexpr(Layout == naive_attention_layout_enum::PHSD)
return static_cast<int64_t>(i_h * s * d + page_offset * d + i_d) + base_;
else if constexpr(Layout == naive_attention_layout_enum::PHDSX)
{
int d_r = i_d / x;
int d_x = i_d % x;
return static_cast<int64_t>(i_h * d * s + d_r * s * x + page_offset * x + d_x) +
base_;
}
else if constexpr(Layout == naive_attention_layout_enum::PHDS)
{
return static_cast<int64_t>(i_h * d * s + i_d * s + page_offset) + base_;
}
}
// below set of API will directly use pointer inside this struct
__device__ void init(int /*i_b*/, int i_h_) { i_h = i_h_; }
__device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
__device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
};
template <typename T, naive_attention_layout_enum Layout>
struct kvscale_addresser
{
int s, h, d; // seqlen(tokens), nhead, hdim
T* base_ptr;
__device__ kvscale_addresser(int s_, int h_, int d_, void* p_)
: s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
{
}
__device__ int get_offset(int i_s, int i_h, int i_d)
{
if constexpr(Layout == naive_attention_layout_enum::SCALE_HS)
{
// [nhead, tokens]
(void)i_d;
return i_h * s + i_s;
}
else if constexpr(Layout == naive_attention_layout_enum::DEFAULT)
{
return 0;
}
// [h, 2, d]
// return i_h * 2 * d + i_kv * d + i_d;
}
__device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
};
__device__ __host__ static constexpr int get_block_size() { return kBlockSize; }
// for simpliciy, 1 WG always compute 1 token along q, compute all token along kv
// compute all hdim from q, compute WG_SIZE hdim from v
// 1) in prefill case, seqlen_q >= 1, seqlen_kv >= 1, batch_q=batch_kv
// 2) in decode case, seqlen_q = 1, batch_q is input num-tokens, batch_kv is 1
// 3) in paged-attn case, we still use 1 WG compute all the seqlen-kv for simplicity
// TODO: could support split-kv to validate intermediate logsum
__host__ static dim3 get_grid_size(naive_attention_fwd_args args)
{
constexpr int wg_size = get_block_size();
auto g =
dim3((args.hdim_v + wg_size - 1) / wg_size, args.seqlen_q, args.batch_q * args.nhead_q);
return g;
}
// reduce single pixel within a wave
template <typename T, typename F>
__device__ constexpr T wave_reduce(T local, F reduce_f)
{
// constexpr int wave_size = 64;
constexpr int reduce_stage = 6; // 1<<6=64
T v_local = local;
#pragma unroll
for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
{
int src_lane = __lane_id() ^ (1 << i_stage);
int32_t v_remote_tmp =
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
T v_remote = bit_cast<T>(v_remote_tmp);
v_local = reduce_f(v_local, v_remote);
}
return v_local;
}
// Note: this function must be called after wave_reduce
// Note: better not use this under if...else... with thread divergence (syncthreads)
template <typename T, typename F>
__device__ constexpr T cross_wave_reduce(T local, F reduce_f, T* smem)
{
constexpr int waves = 4;
constexpr int wave_size = 64;
int lane_id = threadIdx.x % wave_size;
__syncthreads();
smem[threadIdx.x] = local;
__syncthreads();
// the data within single wave is the same
// but for simplicity, we still use data from each lane.
T v_local = smem[lane_id];
#pragma unroll
for(int i_stage = 1; i_stage < waves; i_stage++)
{
T v_remote = smem[i_stage * wave_size + lane_id];
v_local = reduce_f(v_local, v_remote);
}
return v_local;
}
// kernel entry point
__device__ void operator()(naive_attention_fwd_args args)
{
constexpr int wg_size = get_block_size();
__shared__ char smem[wg_size * 4 * sizeof(float)]; // should enough
char* smem_quant_q = smem + wg_size * 2 * sizeof(float); // second half, should enough
int i_dv = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v
int i_sq = blockIdx.y; // index of seqlen_q
int i_batch = blockIdx.z; // index of batch_q * nhead_q
int i_bq = i_batch / args.nhead_q; // index of batch_q
int i_hq = i_batch % args.nhead_q; // index of nhead_q
int i_bk = i_bq / args.batch_ratio_kv;
int i_hk = i_hq / args.nhead_ratio_kv;
void* page_table_ptr = [&]() {
if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
{
return reinterpret_cast<int*>(args.page_table_ptr) + i_bq * args.max_pages_per_seq;
}
else
{
return nullptr;
}
}();
auto q_addr = [&]() {
if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
{
return addresser<QType, QLayout>{
args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
}
else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
{
return addresser<QType, QLayout>{
args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
}
}();
auto k_addr = [&]() {
if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
{
return addresser<KType, KLayout>{
args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim, args.k_ptr};
}
else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
{
return page_addresser<KType, KLayout>{
args.page_size, args.nhead_kv, args.hdim, args.k_ptr, page_table_ptr};
}
}();
auto v_addr = [&]() {
if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
{
return addresser<VType, VLayout>{
args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim_v, args.v_ptr};
}
else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
{
return page_addresser<VType, VLayout>{
args.page_size, args.nhead_kv, args.hdim_v, args.v_ptr, page_table_ptr};
}
}();
auto o_addr = [&]() {
if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
{
return addresser<OType, OLayout>{
args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
}
else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
{
return addresser<OType, OLayout>{
args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
}
}();
q_addr.init(i_bq, i_hq);
k_addr.init(i_bk, i_hk);
v_addr.init(i_bk, i_hk);
o_addr.init(i_bq, i_hq);
auto f_max = [](auto x_, auto y_) { return max(x_, y_); };
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
auto f_absmax_f32 = [](float v_0_, float v_1_) {
// float rtn;
// asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
// return rtn;
return max(abs(v_0_), abs(v_1_));
};
int seqlen_kv = [&]() {
if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
{
return args.seqlen_kv;
}
else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
{
return reinterpret_cast<int*>(args.context_len_ptr)[i_bq];
}
}();
SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
SoftmaxType l{0};
// AccType o_acc = {0};
OAccType o_acc = {0};
int sk_loops = (seqlen_kv + wg_size - 1) / wg_size;
QuantComputeType q_dequant_scale = .0f;
kvscale_addresser<KVScaleType, KScaleLayout> kscale_addr{
args.max_kv_tokens, args.nhead_kv, args.hdim, args.kscale_ptr};
kvscale_addresser<KVScaleType, VScaleLayout> vscale_addr{
args.max_kv_tokens, args.nhead_kv, args.hdim_v, args.vscale_ptr};
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
{
// AccType is i32 now, seqlen_q = 1, hdim up to 256
AccType q = 0;
AccType k_s = 0;
if(static_cast<int>(threadIdx.x) < args.hdim)
{
q = type_convert<AccType>(q_addr.load(0, threadIdx.x));
k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
}
// 1) we apply the k scale to q
AccType q_forwarded = q * k_s;
// 2) apply smooth-quant
// find absmax
AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));
// per-token scale
q_dequant_scale = type_convert<QuantComputeType>(qf_max) / scale_max<QCompute>::value;
// devide by scale
q = q / q_dequant_scale;
// fp32->i8
QCompute quantized_q = static_cast<QCompute>(q);
__syncthreads();
reinterpret_cast<QCompute*>(smem)[threadIdx.x] = quantized_q;
__syncthreads();
// after above process, we have 2 data
// 1) int8 q data stored in smem(no need to reload)
// 2) per-token scale q_dequant_scale, to be mul after 1st gemm
}
else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
{
// dyanmic quant q here
float q = 0;
if(static_cast<int>(threadIdx.x) < args.hdim)
{
q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
}
// apply smooth-quant
// find absmax
float q_max = wave_reduce(q, f_absmax_f32);
q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));
// per-token scale
q_dequant_scale =
type_convert<QuantComputeType>(q_max) / scale_max<QCompute>::value;
// devide by scale
q = q / q_dequant_scale;
QCompute quantized_q = type_convert<QCompute>(q);
__syncthreads();
reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
__syncthreads();
// after above process, we have 2 data
// 1) fp8 q data stored in smem(no need to reload from global)
// 2) per-token scale q_dequant_scale, to be mul after 1st gemm
}
}
for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
{
int i_sk = i_loop1 * wg_size + threadIdx.x;
// gemm-1
SoftmaxType s_softmax = -numeric<SoftmaxType>::infinity();
if(i_sk < seqlen_kv)
{
AccType s_acc{0}; // clear for every loop
for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
{
auto q = [&]() {
if constexpr(Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERHEAD ||
Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq];
}
else
return q_addr.load(i_sq, i_dq); // q will have duplicate load
}();
auto k = [&]() { return k_addr.load(i_sk, i_dq); }();
s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
}
// scale
s_softmax = type_convert<SoftmaxType>(s_acc);
s_softmax *=
type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
{
s_softmax *= q_dequant_scale; // post scale the per-token factor
}
else if constexpr(Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
SoftmaxType k_per_token_scale =
type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
s_softmax *= q_dequant_scale;
s_softmax *= k_per_token_scale;
}
}
// s->p
QuantComputeType p_dequant_scale = 1.;
{
// softmax, find max
SoftmaxType old_max = row_max;
SoftmaxType cur_max = wave_reduce(s_softmax, f_max);
cur_max = cross_wave_reduce(cur_max, f_max, reinterpret_cast<SoftmaxType*>(smem));
row_max = max(old_max, cur_max); // update row_max
// softmax, exp(i_elem - max)
SoftmaxType p_compute = __builtin_amdgcn_exp2f(s_softmax - row_max);
// compute exp_sum
SoftmaxType row_sum = wave_reduce(p_compute, f_sum);
row_sum = cross_wave_reduce(row_sum, f_sum, reinterpret_cast<SoftmaxType*>(smem));
// l, pre-scall o_acc
SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
l = tmp * l + row_sum;
o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
// prepare the p_compute into smem, to let every thread read same p_compute and do
// 2nd gemm
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
{
QuantComputeType v_s = 0;
if(static_cast<int>(threadIdx.x) < args.hdim_v)
{
v_s =
type_convert<QuantComputeType>(vscale_addr.load(i_hk, threadIdx.x, 1));
}
// 1) we apply the v scale to p
QuantComputeType p_forwarded = p_compute * v_s;
// 2) apply smooth-quant
// find absmax
QuantComputeType pf_max = wave_reduce(p_forwarded, f_absmax_f32);
pf_max = cross_wave_reduce(
pf_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
// per-token scale
p_dequant_scale = pf_max / scale_max<PType>::value; // 127.0;
// devide by scale
p_compute = p_compute / p_dequant_scale;
// fp32->i8
PType quantized_p = static_cast<PType>(p_compute);
__syncthreads();
reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
__syncthreads();
// after above process, we have 2 data
// 1) int8 p data stored in smem(no need to reload)
// 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
}
else if constexpr(Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
// forward apply the v scale to p_compute, this is compute friendly
auto v_scale = type_convert<QuantComputeType>(vscale_addr.load(i_sk, i_hk, 0));
p_compute *= v_scale;
// smooth-quant
// find absmax
QuantComputeType p_max = wave_reduce(p_compute, f_absmax_f32);
p_max = cross_wave_reduce(
p_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
// per-token scale
p_dequant_scale = p_max / scale_max<PType>::value; // 240.0;
// devide by scale
p_compute = p_compute / p_dequant_scale;
// fp32->i8
PType quantized_p = type_convert<PType>(p_compute);
__syncthreads();
reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
__syncthreads();
// after above process, we have 2 data
// 1) fp8_t p data stored in smem(no need to reload)
// 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
}
else
{
__syncthreads();
reinterpret_cast<PType*>(smem)[threadIdx.x] = type_convert<PType>(p_compute);
__syncthreads();
}
}
// gemm-2, simple loop over vector by vector
constexpr int gemm_2_loop = wg_size / p_vec_elem;
{
AccType o_acc_local = {0};
int sk_start = i_loop1 * wg_size; // we start from the first seqlen_kv element
for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
{
p_vec_type p_vec = reinterpret_cast<p_vec_type*>(smem)[i_loop2];
#pragma unroll
for(int i_j = 0; i_j < p_vec_elem; i_j++)
{
int sv_offset = i_loop2 * p_vec_elem + i_j;
int i_sv = sk_start + sv_offset;
VType v = 0;
if(i_dv < args.hdim_v && i_sv < seqlen_kv)
{
v = v_addr.load(i_sv, i_dv);
}
AccType v_compute = [&]() { return type_convert<AccType>(v); }();
o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
}
}
OAccType post_scale_o_acc_local = [&]() {
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
{
// apply pr scale to local acc
return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
p_dequant_scale);
}
else if constexpr(Traits::quant_algo ==
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
{
// apply pr scale to local acc
return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
p_dequant_scale);
}
else
{
return type_convert<OAccType>(o_acc_local);
}
}();
o_acc += post_scale_o_acc_local;
}
}
// post scale o_acc
{
SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking
o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
}
// store O
if(i_dv < args.hdim_v)
o_addr.store(type_convert<OType>(o_acc), i_sq, i_dv);
}
};
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_() \
{ \
using ktraits_ = naive_attention_fwd_kernel_traits< \
static_cast<naive_attention_variation_enum>(variation_), \
static_cast<naive_attention_quant_algo>(quant_algo_)>; \
using k_ = naive_attention_fwd_kernel<q_type_, \
k_type_, \
v_type_, \
o_type_, \
acc_type_, \
kvscale_type_, \
q_layout_, \
k_layout_, \
v_layout_, \
o_layout_, \
k_scale_layout_, \
v_scale_layout_, \
ktraits_>; \
dim3 grids = k_::get_grid_size(a); \
r = ck_tile::launch_kernel(s, \
ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a)); \
}
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_() \
if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
t.o_layout == "bshd") \
{ \
constexpr auto q_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto k_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto v_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto o_layout_ = naive_attention_layout_enum::BSHD; \
constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr int variation_ = 0; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
} \
else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" && \
t.v_layout == "bhsd" && t.o_layout == "bhsd") \
{ \
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto v_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
constexpr int variation_ = 0; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
} \
else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" && \
t.v_layout == "phds" && t.o_layout == "bhsd") \
{ \
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX; \
constexpr auto v_layout_ = naive_attention_layout_enum::PHDS; \
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
constexpr int variation_ = 2; \
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
}
//
CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
naive_attention_fwd_args a,
ck_tile::stream_config s)
{
float r = -1;
// TODO: do not explicitly create too much instance!
if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16" &&
t.quant_algo == 0)
{
using q_type_ = fp16_t;
using k_type_ = fp16_t;
using v_type_ = fp16_t;
using o_type_ = fp16_t;
using acc_type_ = float;
using kvscale_type_ = float;
constexpr int quant_algo_ = 0;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16" &&
t.quant_algo == 0)
{
using q_type_ = bf16_t;
using k_type_ = bf16_t;
using v_type_ = bf16_t;
using o_type_ = bf16_t;
using acc_type_ = float;
using kvscale_type_ = float;
constexpr int quant_algo_ = 0;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "bf16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "bf16" &&
t.quant_algo == 2)
{
using q_type_ = bf16_t;
using k_type_ = fp8_t;
using v_type_ = fp8_t;
using o_type_ = bf16_t;
using acc_type_ = float; // NOTE!
using kvscale_type_ = float;
constexpr int quant_algo_ = 2;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "fp16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "fp16" &&
t.quant_algo == 2)
{
using q_type_ = fp16_t;
using k_type_ = fp8_t;
using v_type_ = fp8_t;
using o_type_ = fp16_t;
using acc_type_ = float; // NOTE!
using kvscale_type_ = float;
constexpr int quant_algo_ = 2;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16" &&
t.quant_algo == 2)
{
using q_type_ = bf16_t;
using k_type_ = int8_t;
using v_type_ = int8_t;
using o_type_ = bf16_t;
using acc_type_ = int32_t; // NOTE!
using kvscale_type_ = float;
constexpr int quant_algo_ = 2;
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
}
return r;
}
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_
} // namespace ck_tile

View File

@@ -0,0 +1,360 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ref/conv_common.hpp"
#include <array>
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <hip/hip_runtime.h>
namespace ck_tile {
// Naive GPU reference kernel struct for backward data grouped convolution
// Computes gradient with respect to input
// Layout: Input_grad=NDHWGC, Weight=GKZYXC, Output_grad=NDHWGK (for 3D case)
// Input_grad=NHWGC, Weight=GKYXC, Output_grad=NHWGK (for 2D case)
// Input_grad=NWGC, Weight=GKXC, Output_grad=NWGK (for 1D case)
//
// One thread per input element, uses grid-stride loop pattern
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct naive_grouped_conv_bwd_data_kernel
{
static constexpr ck_tile::index_t kBlockSize = 256;
__device__ void
operator()(InDataType* __restrict__ p_in_grad,
const WeiDataType* __restrict__ p_wei,
const OutDataType* __restrict__ p_out_grad,
// Tensor dimensions
ck_tile::index_t G, // number of groups
ck_tile::index_t N, // batch size
ck_tile::index_t K, // output channels per group
ck_tile::index_t C, // input channels per group
// Input spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& in_spatial_lengths,
// Weight spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& wei_spatial_lengths,
// Output spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& out_spatial_lengths,
// Convolution parameters
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_strides,
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_dilations,
const std::array<ck_tile::long_index_t, NDimSpatial>& in_left_pads) const
{
const ck_tile::long_index_t tid = get_block_id() * blockDim.x + get_thread_id();
const ck_tile::long_index_t num_threads = blockDim.x * gridDim.x;
// Calculate total input elements
ck_tile::long_index_t input_length = G * N * C;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
input_length *= in_spatial_lengths[i];
}
// Calculate strides for input tensor (NDHWGC or NHWGC or NWGC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> in_strides;
ck_tile::long_index_t stride = 1;
in_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
in_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
in_strides[i + 1] = stride;
stride *= in_spatial_lengths[i];
}
in_strides[0] = stride; // N stride
// Calculate strides for output tensor (NDHWGK or NHWGK or NWGK)
std::array<ck_tile::long_index_t, NDimSpatial + 3> out_strides;
stride = 1;
out_strides[NDimSpatial + 2] = stride; // K stride
stride *= K;
out_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
out_strides[i + 1] = stride;
stride *= out_spatial_lengths[i];
}
out_strides[0] = stride; // N stride
// Calculate strides for weight tensor (GKZYXC or GKYXC or GKXC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> wei_strides;
stride = 1;
wei_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
wei_strides[i + 2] = stride;
stride *= wei_spatial_lengths[i];
}
wei_strides[1] = stride; // K stride
stride *= K;
wei_strides[0] = stride; // G stride
// Grid-stride loop over all input elements
for(ck_tile::long_index_t ii = tid; ii < input_length; ii += num_threads)
{
// Decode linear index to multi-dimensional indices
ck_tile::long_index_t tmp = ii;
// Extract N (batch)
ck_tile::index_t n = tmp / in_strides[0];
tmp -= n * in_strides[0];
// Extract spatial dimensions
ck_tile::index_t in_spatial_idx[6];
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
in_spatial_idx[i] = tmp / in_strides[i + 1];
tmp -= in_spatial_idx[i] * in_strides[i + 1];
}
// Extract G (group)
ck_tile::index_t g = tmp / in_strides[NDimSpatial + 1];
tmp -= g * in_strides[NDimSpatial + 1];
// Extract C (input channel)
ck_tile::index_t c = tmp;
// Accumulate in float
float v_acc = 0.0f;
// Loop over output channels
for(ck_tile::index_t k = 0; k < K; ++k)
{
// Loop over filter spatial dimensions
if constexpr(NDimSpatial == 1)
{
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[0]; ++x)
{
// Calculate output spatial coordinate (inverse of forward)
ck_tile::long_index_t w_tmp =
static_cast<ck_tile::long_index_t>(in_spatial_idx[0]) +
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]);
// Check if this maps to valid output position
if(w_tmp % conv_strides[0] == 0)
{
ck_tile::long_index_t wo = w_tmp / conv_strides[0];
if(wo >= 0 && wo < out_spatial_lengths[0])
{
std::array<ck_tile::index_t, 1> out_spatial = {
static_cast<index_t>(wo)};
std::array<ck_tile::index_t, 1> wei_spatial = {x};
ck_tile::long_index_t out_idx = detail::calculate_output_index<1>(
n, g, k, out_spatial, out_strides);
ck_tile::long_index_t wei_idx = detail::calculate_weight_index<1>(
g, k, c, wei_spatial, wei_strides);
v_acc += type_convert<float>(p_out_grad[out_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
}
else if constexpr(NDimSpatial == 2)
{
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[0]; ++y)
{
ck_tile::long_index_t h_tmp =
static_cast<ck_tile::long_index_t>(in_spatial_idx[0]) +
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]);
if(h_tmp % conv_strides[0] == 0)
{
ck_tile::long_index_t ho = h_tmp / conv_strides[0];
if(ho >= 0 && ho < out_spatial_lengths[0])
{
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[1]; ++x)
{
ck_tile::long_index_t w_tmp =
static_cast<ck_tile::long_index_t>(in_spatial_idx[1]) +
static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]);
if(w_tmp % conv_strides[1] == 0)
{
ck_tile::long_index_t wo = w_tmp / conv_strides[1];
if(wo >= 0 && wo < out_spatial_lengths[1])
{
std::array<ck_tile::index_t, 2> out_spatial = {
static_cast<index_t>(ho), static_cast<index_t>(wo)};
std::array<ck_tile::index_t, 2> wei_spatial = {y, x};
ck_tile::long_index_t out_idx =
detail::calculate_output_index<2>(
n, g, k, out_spatial, out_strides);
ck_tile::long_index_t wei_idx =
detail::calculate_weight_index<2>(
g, k, c, wei_spatial, wei_strides);
v_acc += type_convert<float>(p_out_grad[out_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
}
}
}
}
else if constexpr(NDimSpatial == 3)
{
for(ck_tile::index_t z = 0; z < wei_spatial_lengths[0]; ++z)
{
ck_tile::long_index_t d_tmp =
static_cast<ck_tile::long_index_t>(in_spatial_idx[0]) +
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]);
if(d_tmp % conv_strides[0] == 0)
{
ck_tile::long_index_t do_ = d_tmp / conv_strides[0];
if(do_ >= 0 && do_ < out_spatial_lengths[0])
{
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[1]; ++y)
{
ck_tile::long_index_t h_tmp =
static_cast<ck_tile::long_index_t>(in_spatial_idx[1]) +
static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]);
if(h_tmp % conv_strides[1] == 0)
{
ck_tile::long_index_t ho = h_tmp / conv_strides[1];
if(ho >= 0 && ho < out_spatial_lengths[1])
{
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[2];
++x)
{
ck_tile::long_index_t w_tmp =
static_cast<ck_tile::long_index_t>(
in_spatial_idx[2]) +
static_cast<ck_tile::long_index_t>(
in_left_pads[2]) -
static_cast<ck_tile::long_index_t>(
x * conv_dilations[2]);
if(w_tmp % conv_strides[2] == 0)
{
ck_tile::long_index_t wo =
w_tmp / conv_strides[2];
if(wo >= 0 && wo < out_spatial_lengths[2])
{
std::array<ck_tile::index_t, 3>
out_spatial = {
static_cast<index_t>(do_),
static_cast<index_t>(ho),
static_cast<index_t>(wo)};
std::array<ck_tile::index_t, 3>
wei_spatial = {z, y, x};
ck_tile::long_index_t out_idx =
detail::calculate_output_index<3>(
n, g, k, out_spatial, out_strides);
ck_tile::long_index_t wei_idx =
detail::calculate_weight_index<3>(
g, k, c, wei_spatial, wei_strides);
v_acc +=
type_convert<float>(
p_out_grad[out_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
}
}
}
}
}
}
}
}
// Convert accumulator to output type and write
p_in_grad[ii] = type_convert<InDataType>(v_acc);
}
}
};
// Host-side launcher for naive grouped convolution backward data
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
CK_TILE_HOST float
naive_grouped_conv_bwd_data(InDataType* p_in_grad_dev,
const WeiDataType* p_wei_dev,
const OutDataType* p_out_grad_dev,
ck_tile::index_t G,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t C,
std::vector<ck_tile::long_index_t> in_spatial_lengths,
std::vector<ck_tile::long_index_t> wei_spatial_lengths,
std::vector<ck_tile::long_index_t> out_spatial_lengths,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
ck_tile::stream_config stream_config = {})
{
// Convert vectors to arrays
auto in_spatial_arr = to_array_with_default<NDimSpatial>(in_spatial_lengths);
auto wei_spatial_arr = to_array_with_default<NDimSpatial>(wei_spatial_lengths);
auto out_spatial_arr = to_array_with_default<NDimSpatial>(out_spatial_lengths);
auto conv_strides_arr = to_array_with_default<NDimSpatial>(conv_strides);
auto conv_dilations_arr = to_array_with_default<NDimSpatial>(conv_dilations);
auto in_left_pads_arr = to_array_with_default<NDimSpatial>(in_left_pads, 0);
// Calculate grid size
ck_tile::long_index_t input_length = G * N * C;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
input_length *= in_spatial_lengths[i];
}
using KernelType =
naive_grouped_conv_bwd_data_kernel<NDimSpatial, InDataType, WeiDataType, OutDataType>;
constexpr ck_tile::index_t block_size = KernelType::kBlockSize;
const ck_tile::index_t grid_size = (input_length + block_size - 1) / block_size;
// Launch kernel
float elapsed_ms = launch_kernel(stream_config,
make_kernel(KernelType{},
dim3(grid_size),
dim3(block_size),
0, // dynamic shared memory size
p_in_grad_dev,
p_wei_dev,
p_out_grad_dev,
G,
N,
K,
C,
in_spatial_arr,
wei_spatial_arr,
out_spatial_arr,
conv_strides_arr,
conv_dilations_arr,
in_left_pads_arr));
return elapsed_ms;
}
} // namespace ck_tile

View File

@@ -0,0 +1,324 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ref/conv_common.hpp"
#include <array>
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <hip/hip_runtime.h>
namespace ck_tile {
// Naive GPU reference kernel struct for backward weight grouped convolution
// Computes gradient with respect to weights
// Layout: Input=NDHWGC, Weight_grad=GKZYXC, Output_grad=NDHWGK (for 3D case)
// Input=NHWGC, Weight_grad=GKYXC, Output_grad=NHWGK (for 2D case)
// Input=NWGC, Weight_grad=GKXC, Output_grad=NWGK (for 1D case)
//
// One thread per weight element, uses grid-stride loop pattern
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct naive_grouped_conv_bwd_weight_kernel
{
static constexpr ck_tile::index_t kBlockSize = 256;
__device__ void
operator()(const InDataType* __restrict__ p_in,
WeiDataType* __restrict__ p_wei_grad,
const OutDataType* __restrict__ p_out_grad,
// Tensor dimensions
ck_tile::index_t G, // number of groups
ck_tile::index_t N, // batch size
ck_tile::index_t K, // output channels per group
ck_tile::index_t C, // input channels per group
// Input spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& in_spatial_lengths,
// Weight spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& wei_spatial_lengths,
// Output spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& out_spatial_lengths,
// Convolution parameters
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_strides,
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_dilations,
const std::array<ck_tile::long_index_t, NDimSpatial>& in_left_pads) const
{
const ck_tile::long_index_t tid = get_block_id() * blockDim.x + get_thread_id();
const ck_tile::long_index_t num_threads = blockDim.x * gridDim.x;
// Calculate total weight elements
ck_tile::long_index_t weight_length = G * K * C;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
weight_length *= wei_spatial_lengths[i];
}
// Calculate strides for weight tensor (GKZYXC or GKYXC or GKXC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> wei_strides;
ck_tile::long_index_t stride = 1;
wei_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
wei_strides[i + 2] = stride;
stride *= wei_spatial_lengths[i];
}
wei_strides[1] = stride; // K stride
stride *= K;
wei_strides[0] = stride; // G stride
// Calculate strides for input tensor (NDHWGC or NHWGC or NWGC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> in_strides;
stride = 1;
in_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
in_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
in_strides[i + 1] = stride;
stride *= in_spatial_lengths[i];
}
in_strides[0] = stride; // N stride
// Calculate strides for output tensor (NDHWGK or NHWGK or NWGK)
std::array<ck_tile::long_index_t, NDimSpatial + 3> out_strides;
stride = 1;
out_strides[NDimSpatial + 2] = stride; // K stride
stride *= K;
out_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
out_strides[i + 1] = stride;
stride *= out_spatial_lengths[i];
}
out_strides[0] = stride; // N stride
// Grid-stride loop over all weight elements
for(ck_tile::long_index_t ii = tid; ii < weight_length; ii += num_threads)
{
// Decode linear index to multi-dimensional indices
ck_tile::long_index_t tmp = ii;
// Extract G (group)
ck_tile::index_t g = tmp / wei_strides[0];
tmp -= g * wei_strides[0];
// Extract K (output channel)
ck_tile::index_t k = tmp / wei_strides[1];
tmp -= k * wei_strides[1];
// Extract spatial dimensions (come before C in GKZYXC layout)
ck_tile::index_t wei_spatial_idx[6];
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
wei_spatial_idx[i] = tmp / wei_strides[i + 2];
tmp -= wei_spatial_idx[i] * wei_strides[i + 2];
}
// Extract C (input channel) - comes last
ck_tile::index_t c = tmp;
// Accumulate in float
float v_acc = 0.0f;
// Loop over batch
for(ck_tile::index_t n = 0; n < N; ++n)
{
// Loop over output spatial dimensions
if constexpr(NDimSpatial == 1)
{
for(ck_tile::index_t wo = 0; wo < out_spatial_lengths[0]; ++wo)
{
// Calculate input spatial coordinate
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[0] *
conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
// Bounds check
if(wi >= 0 && wi < in_spatial_lengths[0])
{
std::array<ck_tile::index_t, 1> in_spatial = {static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 1> out_spatial = {
static_cast<index_t>(wo)};
ck_tile::long_index_t in_idx =
detail::calculate_input_index<1>(n, g, c, in_spatial, in_strides);
ck_tile::long_index_t out_idx = detail::calculate_output_index<1>(
n, g, k, out_spatial, out_strides);
v_acc += type_convert<float>(p_out_grad[out_idx]) *
type_convert<float>(p_in[in_idx]);
}
}
}
else if constexpr(NDimSpatial == 2)
{
for(ck_tile::index_t ho = 0; ho < out_spatial_lengths[0]; ++ho)
{
ck_tile::long_index_t hi =
static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[0] *
conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(ck_tile::index_t wo = 0; wo < out_spatial_lengths[1]; ++wo)
{
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[1] *
conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
// Bounds check
if(hi >= 0 && hi < in_spatial_lengths[0] && wi >= 0 &&
wi < in_spatial_lengths[1])
{
std::array<ck_tile::index_t, 2> in_spatial = {
static_cast<index_t>(hi), static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 2> out_spatial = {
static_cast<index_t>(ho), static_cast<index_t>(wo)};
ck_tile::long_index_t in_idx = detail::calculate_input_index<2>(
n, g, c, in_spatial, in_strides);
ck_tile::long_index_t out_idx = detail::calculate_output_index<2>(
n, g, k, out_spatial, out_strides);
v_acc += type_convert<float>(p_out_grad[out_idx]) *
type_convert<float>(p_in[in_idx]);
}
}
}
}
else if constexpr(NDimSpatial == 3)
{
for(ck_tile::index_t do_ = 0; do_ < out_spatial_lengths[0]; ++do_)
{
ck_tile::long_index_t di =
static_cast<ck_tile::long_index_t>(do_ * conv_strides[0]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[0] *
conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(ck_tile::index_t ho = 0; ho < out_spatial_lengths[1]; ++ho)
{
ck_tile::long_index_t hi =
static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[1] *
conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
for(ck_tile::index_t wo = 0; wo < out_spatial_lengths[2]; ++wo)
{
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
static_cast<ck_tile::long_index_t>(wei_spatial_idx[2] *
conv_dilations[2]) -
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
// Bounds check
if(di >= 0 && di < in_spatial_lengths[0] && hi >= 0 &&
hi < in_spatial_lengths[1] && wi >= 0 &&
wi < in_spatial_lengths[2])
{
std::array<ck_tile::index_t, 3> in_spatial = {
static_cast<index_t>(di),
static_cast<index_t>(hi),
static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 3> out_spatial = {
static_cast<index_t>(do_),
static_cast<index_t>(ho),
static_cast<index_t>(wo)};
ck_tile::long_index_t in_idx = detail::calculate_input_index<3>(
n, g, c, in_spatial, in_strides);
ck_tile::long_index_t out_idx =
detail::calculate_output_index<3>(
n, g, k, out_spatial, out_strides);
v_acc += type_convert<float>(p_out_grad[out_idx]) *
type_convert<float>(p_in[in_idx]);
}
}
}
}
}
}
// Convert accumulator to output type and write
p_wei_grad[ii] = type_convert<WeiDataType>(v_acc);
}
}
};
// Host-side launcher for naive grouped convolution backward weight
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
CK_TILE_HOST float
naive_grouped_conv_bwd_weight(const InDataType* p_in_dev,
WeiDataType* p_wei_grad_dev,
const OutDataType* p_out_grad_dev,
ck_tile::index_t G,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t C,
std::vector<ck_tile::long_index_t> in_spatial_lengths,
std::vector<ck_tile::long_index_t> wei_spatial_lengths,
std::vector<ck_tile::long_index_t> out_spatial_lengths,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
ck_tile::stream_config stream_config = {})
{
// Convert vectors to arrays
auto in_spatial_arr = to_array_with_default<NDimSpatial>(in_spatial_lengths);
auto wei_spatial_arr = to_array_with_default<NDimSpatial>(wei_spatial_lengths);
auto out_spatial_arr = to_array_with_default<NDimSpatial>(out_spatial_lengths);
auto conv_strides_arr = to_array_with_default<NDimSpatial>(conv_strides);
auto conv_dilations_arr = to_array_with_default<NDimSpatial>(conv_dilations);
auto in_left_pads_arr = to_array_with_default<NDimSpatial>(in_left_pads, 0);
// Calculate grid size
ck_tile::long_index_t weight_length = G * K * C;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
weight_length *= wei_spatial_lengths[i];
}
using KernelType =
naive_grouped_conv_bwd_weight_kernel<NDimSpatial, InDataType, WeiDataType, OutDataType>;
constexpr ck_tile::index_t block_size = KernelType::kBlockSize;
const ck_tile::index_t grid_size = (weight_length + block_size - 1) / block_size;
// Launch kernel
float elapsed_ms = launch_kernel(stream_config,
make_kernel(KernelType{},
dim3(grid_size),
dim3(block_size),
0, // dynamic shared memory size
p_in_dev,
p_wei_grad_dev,
p_out_grad_dev,
G,
N,
K,
C,
in_spatial_arr,
wei_spatial_arr,
out_spatial_arr,
conv_strides_arr,
conv_dilations_arr,
in_left_pads_arr));
return elapsed_ms;
}
} // namespace ck_tile

View File

@@ -0,0 +1,317 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ref/conv_common.hpp"
#include <array>
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include <hip/hip_runtime.h>
namespace ck_tile {
// Naive GPU reference kernel struct for forward grouped convolution
// Layout: Input=NDHWGC, Weight=GKZYXC, Output=NDHWGK (for 3D case)
// Input=NHWGC, Weight=GKYXC, Output=NHWGK (for 2D case)
// Input=NWGC, Weight=GKXC, Output=NWGK (for 1D case)
//
// One thread per output element, uses grid-stride loop pattern
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
struct naive_grouped_conv_fwd_kernel
{
static constexpr ck_tile::index_t kBlockSize = 256;
__device__ void
operator()(const InDataType* __restrict__ p_in,
const WeiDataType* __restrict__ p_wei,
OutDataType* __restrict__ p_out,
// Tensor dimensions
ck_tile::index_t G, // number of groups
ck_tile::index_t N, // batch size
ck_tile::index_t K, // output channels per group
ck_tile::index_t C, // input channels per group
// Input spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& in_spatial_lengths,
// Weight spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& wei_spatial_lengths,
// Output spatial dimensions
const std::array<ck_tile::long_index_t, NDimSpatial>& out_spatial_lengths,
// Convolution parameters
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_strides,
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_dilations,
const std::array<ck_tile::long_index_t, NDimSpatial>& in_left_pads) const
{
const ck_tile::long_index_t tid = get_block_id() * blockDim.x + get_thread_id();
const ck_tile::long_index_t num_threads = blockDim.x * gridDim.x;
// Calculate total output elements
ck_tile::long_index_t output_length = G * N * K;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
output_length *= out_spatial_lengths[i];
}
// Calculate strides for output tensor (NDHWGK or NHWGK or NWGK)
std::array<ck_tile::long_index_t, NDimSpatial + 3> out_strides; // N, spatial dims, G, K
ck_tile::long_index_t stride = 1;
out_strides[NDimSpatial + 2] = stride; // K stride
stride *= K;
out_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i) // Spatial strides (reversed)
{
out_strides[i + 1] = stride;
stride *= out_spatial_lengths[i];
}
out_strides[0] = stride; // N stride
// Calculate strides for input tensor (NDHWGC or NHWGC or NWGC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> in_strides;
stride = 1;
in_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
in_strides[NDimSpatial + 1] = stride; // G stride
stride *= G;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
in_strides[i + 1] = stride;
stride *= in_spatial_lengths[i];
}
in_strides[0] = stride; // N stride
// Calculate strides for weight tensor (GKZYXC or GKYXC or GKXC)
std::array<ck_tile::long_index_t, NDimSpatial + 3> wei_strides;
stride = 1;
wei_strides[NDimSpatial + 2] = stride; // C stride
stride *= C;
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
{
wei_strides[i + 2] = stride;
stride *= wei_spatial_lengths[i];
}
wei_strides[1] = stride; // K stride
stride *= K;
wei_strides[0] = stride; // G stride
// Grid-stride loop over all output elements
for(ck_tile::long_index_t ii = tid; ii < output_length; ii += num_threads)
{
// Decode linear index to multi-dimensional indices
ck_tile::long_index_t tmp = ii;
// Extract N (batch)
ck_tile::index_t n = tmp / out_strides[0];
tmp -= n * out_strides[0];
// Extract spatial dimensions (D, H, W)
ck_tile::index_t out_spatial_idx[6]; // Max 6 spatial dimensions
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
out_spatial_idx[i] = tmp / out_strides[i + 1];
tmp -= out_spatial_idx[i] * out_strides[i + 1];
}
// Extract G (group)
ck_tile::index_t g = tmp / out_strides[NDimSpatial + 1];
tmp -= g * out_strides[NDimSpatial + 1];
// Extract K (output channel)
ck_tile::index_t k = tmp;
// Accumulate in float
float v_acc = 0.0f;
// Loop over input channels
for(ck_tile::index_t c = 0; c < C; ++c)
{
// Loop over filter spatial dimensions
if constexpr(NDimSpatial == 1)
{
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[0]; ++x)
{
// Calculate input spatial coordinate
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
conv_strides[0]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
// Bounds check
if(wi >= 0 && wi < in_spatial_lengths[0])
{
std::array<ck_tile::index_t, 1> in_spatial = {static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 1> wei_spatial = {x};
ck_tile::long_index_t in_idx =
detail::calculate_input_index<1>(n, g, c, in_spatial, in_strides);
ck_tile::long_index_t wei_idx = detail::calculate_weight_index<1>(
g, k, c, wei_spatial, wei_strides);
v_acc += type_convert<float>(p_in[in_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
else if constexpr(NDimSpatial == 2)
{
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[0]; ++y)
{
ck_tile::long_index_t hi =
static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
conv_strides[0]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[1]; ++x)
{
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(out_spatial_idx[1] *
conv_strides[1]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
// Bounds check
if(hi >= 0 && hi < in_spatial_lengths[0] && wi >= 0 &&
wi < in_spatial_lengths[1])
{
std::array<ck_tile::index_t, 2> in_spatial = {
static_cast<index_t>(hi), static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 2> wei_spatial = {y, x};
ck_tile::long_index_t in_idx = detail::calculate_input_index<2>(
n, g, c, in_spatial, in_strides);
ck_tile::long_index_t wei_idx = detail::calculate_weight_index<2>(
g, k, c, wei_spatial, wei_strides);
v_acc += type_convert<float>(p_in[in_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
}
else if constexpr(NDimSpatial == 3)
{
for(ck_tile::index_t z = 0; z < wei_spatial_lengths[0]; ++z)
{
ck_tile::long_index_t di =
static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
conv_strides[0]) +
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[1]; ++y)
{
ck_tile::long_index_t hi =
static_cast<ck_tile::long_index_t>(out_spatial_idx[1] *
conv_strides[1]) +
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[2]; ++x)
{
ck_tile::long_index_t wi =
static_cast<ck_tile::long_index_t>(out_spatial_idx[2] *
conv_strides[2]) +
static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
// Bounds check
if(di >= 0 && di < in_spatial_lengths[0] && hi >= 0 &&
hi < in_spatial_lengths[1] && wi >= 0 &&
wi < in_spatial_lengths[2])
{
std::array<ck_tile::index_t, 3> in_spatial = {
static_cast<index_t>(di),
static_cast<index_t>(hi),
static_cast<index_t>(wi)};
std::array<ck_tile::index_t, 3> wei_spatial = {z, y, x};
ck_tile::long_index_t in_idx = detail::calculate_input_index<3>(
n, g, c, in_spatial, in_strides);
ck_tile::long_index_t wei_idx =
detail::calculate_weight_index<3>(
g, k, c, wei_spatial, wei_strides);
v_acc += type_convert<float>(p_in[in_idx]) *
type_convert<float>(p_wei[wei_idx]);
}
}
}
}
}
}
// Convert accumulator to output type and write
p_out[ii] = type_convert<OutDataType>(v_acc);
}
}
};
// Host-side launcher for naive grouped convolution forward
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
CK_TILE_HOST float naive_grouped_conv_fwd(const InDataType* p_in_dev,
const WeiDataType* p_wei_dev,
OutDataType* p_out_dev,
ck_tile::index_t G,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t C,
std::vector<ck_tile::long_index_t> in_spatial_lengths,
std::vector<ck_tile::long_index_t> wei_spatial_lengths,
std::vector<ck_tile::long_index_t> out_spatial_lengths,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
ck_tile::stream_config stream_config = {})
{
// Convert vectors to arrays (std::array can be passed by value to kernel)
auto in_spatial_arr = to_array_with_default<NDimSpatial>(in_spatial_lengths);
auto wei_spatial_arr = to_array_with_default<NDimSpatial>(wei_spatial_lengths);
auto out_spatial_arr = to_array_with_default<NDimSpatial>(out_spatial_lengths);
auto conv_strides_arr = to_array_with_default<NDimSpatial>(conv_strides);
auto conv_dilations_arr = to_array_with_default<NDimSpatial>(conv_dilations);
auto in_left_pads_arr = to_array_with_default<NDimSpatial>(in_left_pads, 0);
// Calculate grid size
ck_tile::long_index_t output_length = G * N * K;
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
{
output_length *= out_spatial_lengths[i];
}
using KernelType =
naive_grouped_conv_fwd_kernel<NDimSpatial, InDataType, WeiDataType, OutDataType>;
constexpr ck_tile::index_t block_size = KernelType::kBlockSize;
const ck_tile::index_t grid_size = (output_length + block_size - 1) / block_size;
// Launch kernel
float elapsed_ms = launch_kernel(stream_config,
make_kernel(KernelType{},
dim3(grid_size),
dim3(block_size),
0, // dynamic shared memory size
p_in_dev,
p_wei_dev,
p_out_dev,
G,
N,
K,
C,
in_spatial_arr,
wei_spatial_arr,
out_spatial_arr,
conv_strides_arr,
conv_dilations_arr,
in_left_pads_arr));
return elapsed_ms;
}
} // namespace ck_tile