mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
This commit is contained in:
5
include/ck_tile/ref/README.md
Normal file
5
include/ck_tile/ref/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# reference
|
||||
|
||||
this folder contains reference implementation of a specific op. Note by including a specific header, you are including the implementation(expecially the gpu implementation) into your source code, and compile that kernel into the fatbin, hence may increase your kernel obj code length. Usually the header starts with `reference_` is a cpu reference implementation. The header starts with `naive_` contains a gpu implementation with a small launcher.
|
||||
|
||||
TODO: move `host/reference` under this folder
|
||||
95
include/ck_tile/ref/conv_common.hpp
Normal file
95
include/ck_tile/ref/conv_common.hpp
Normal file
@@ -0,0 +1,95 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Helper function to convert std::vector to std::array for kernel parameters
|
||||
template <ck_tile::index_t NDimSpatial>
|
||||
inline std::array<ck_tile::long_index_t, NDimSpatial>
|
||||
to_array(const std::vector<ck_tile::long_index_t>& vec)
|
||||
{
|
||||
std::array<ck_tile::long_index_t, NDimSpatial> arr;
|
||||
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
arr[i] = vec[i];
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
// Helper to fill missing dimensions with default value
|
||||
template <ck_tile::index_t NDimSpatial>
|
||||
inline std::array<ck_tile::long_index_t, NDimSpatial>
|
||||
to_array_with_default(const std::vector<ck_tile::long_index_t>& vec,
|
||||
ck_tile::long_index_t default_val = 1)
|
||||
{
|
||||
std::array<ck_tile::long_index_t, NDimSpatial> arr;
|
||||
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
arr[i] = (static_cast<size_t>(i) < vec.size()) ? vec[i] : default_val;
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
// Index calculation helpers for GPU reference kernels
|
||||
namespace detail {
|
||||
|
||||
// Calculate linear input index for grouped convolution
|
||||
// Layout: [N, spatial..., G, C]
|
||||
template <index_t NDimSpatial>
|
||||
inline __device__ long_index_t
|
||||
calculate_input_index(index_t n,
|
||||
index_t g,
|
||||
index_t c,
|
||||
const std::array<index_t, NDimSpatial>& spatial_idx,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& strides)
|
||||
{
|
||||
long_index_t idx = n * strides[0];
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
idx += spatial_idx[i] * strides[i + 1];
|
||||
idx += g * strides[NDimSpatial + 1] + c;
|
||||
return idx;
|
||||
}
|
||||
|
||||
// Calculate linear weight index for grouped convolution
|
||||
// Layout: [G, K, spatial..., C]
|
||||
template <index_t NDimSpatial>
|
||||
inline __device__ long_index_t
|
||||
calculate_weight_index(index_t g,
|
||||
index_t k,
|
||||
index_t c,
|
||||
const std::array<index_t, NDimSpatial>& spatial_idx,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& strides)
|
||||
{
|
||||
long_index_t idx = g * strides[0] + k * strides[1];
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
idx += spatial_idx[i] * strides[i + 2];
|
||||
idx += c * strides[NDimSpatial + 2];
|
||||
return idx;
|
||||
}
|
||||
|
||||
// Calculate linear output index for grouped convolution
|
||||
// Layout: [N, spatial..., G, K]
|
||||
template <index_t NDimSpatial>
|
||||
inline __device__ long_index_t
|
||||
calculate_output_index(index_t n,
|
||||
index_t g,
|
||||
index_t k,
|
||||
const std::array<index_t, NDimSpatial>& spatial_idx,
|
||||
const std::array<long_index_t, NDimSpatial + 3>& strides)
|
||||
{
|
||||
long_index_t idx = n * strides[0];
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
idx += spatial_idx[i] * strides[i + 1];
|
||||
idx += g * strides[NDimSpatial + 1] + k;
|
||||
return idx;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ck_tile
|
||||
826
include/ck_tile/ref/naive_attention.hpp
Normal file
826
include/ck_tile/ref/naive_attention.hpp
Normal file
@@ -0,0 +1,826 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include <thread>
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum class naive_attention_layout_enum
|
||||
{
|
||||
DEFAULT, // maybe this tensor is not used, set some irrelevant value
|
||||
BSHD, // [batch, seqlen, nhead, hdim]
|
||||
BHSD, // [batch, nhead, seqlen, hdim]
|
||||
BS3HD, // [batch, nhead, 3, seqlen, hdim], used when qkv are packed
|
||||
PHSD, // [pages, nhead, page_size, hdim]
|
||||
// PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen
|
||||
PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen
|
||||
PHDS, // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen
|
||||
|
||||
// scale layout used for dynamic dequant
|
||||
SCALE_HS, // [nhead, tokens] or [nhead, tokens-per-group], nhe KVCache quant
|
||||
SCALE_SH, // [tokens, nhead]
|
||||
};
|
||||
|
||||
// will used to specialize kernel variation
|
||||
enum class naive_attention_variation_enum
|
||||
{
|
||||
FLASH_BATCHED = 0, // standard flash attention, or xformer/sdpa, used for training
|
||||
FLASH_GROUPED,
|
||||
DECODE_PAGED, // decode attn, where kv token from another buffer called kvcache
|
||||
};
|
||||
|
||||
enum class naive_attention_quant_algo
|
||||
{
|
||||
NO = 0,
|
||||
KV_8BIT_PERHEAD = 1,
|
||||
// FP8/INT8 quant for KVCache, per-token quant
|
||||
// [num_tokens, nhead, hdim] -> [nhead, num_tokens]
|
||||
KV_8BIT_PERTOKEN = 2,
|
||||
};
|
||||
|
||||
// TODO: for simplicity, this will be used as host/device arg
|
||||
struct naive_attention_fwd_args
|
||||
{
|
||||
void* q_ptr;
|
||||
void* k_ptr;
|
||||
void* v_ptr;
|
||||
void* o_ptr;
|
||||
void* context_len_ptr; // [batch] used when seqlen kv come from a pointer(each element is a
|
||||
// number, not cumsum)
|
||||
void* page_table_ptr; // [batch, max_pages_per_seq] seqlen_kv is in different block(paged attn)
|
||||
void* kscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
|
||||
void* vscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
|
||||
float scale_s;
|
||||
int hdim;
|
||||
int hdim_v; // could be cross-attn, where V and Q/K hdim are different
|
||||
int batch_q;
|
||||
int batch_kv;
|
||||
int batch_ratio_kv; // batch_q / batch_kv
|
||||
int seqlen_q; // in decode case, this should be 1
|
||||
int seqlen_kv; // if context_len_ptr is not nullptr, ignore this field
|
||||
int nhead_q;
|
||||
int nhead_kv;
|
||||
int nhead_ratio_kv; // nhead_q / nhead_kv
|
||||
int page_size; // if paged, the seqlen-kv per each block
|
||||
int max_pages_per_seq;
|
||||
int max_kv_tokens; // used as stride to access kv scale ptr
|
||||
};
|
||||
|
||||
// this is trait for host API
|
||||
struct naive_attention_fwd_traits
|
||||
{
|
||||
std::string q_type;
|
||||
std::string k_type;
|
||||
std::string v_type;
|
||||
std::string o_type;
|
||||
std::string q_layout;
|
||||
std::string k_layout;
|
||||
std::string v_layout;
|
||||
std::string o_layout;
|
||||
int variation; // sync with naive_attention_variation_enum
|
||||
int quant_algo; // sync with naive_attention_quant_algo
|
||||
};
|
||||
|
||||
// this is trait for kernel template
|
||||
template <naive_attention_variation_enum variation_, naive_attention_quant_algo quant_algo_>
|
||||
struct naive_attention_fwd_kernel_traits
|
||||
{
|
||||
static constexpr naive_attention_variation_enum variation = variation_;
|
||||
static constexpr naive_attention_quant_algo quant_algo = quant_algo_;
|
||||
};
|
||||
|
||||
// for simplicity, please do not use const-reference type for the template type
|
||||
template <typename QType,
|
||||
typename KType,
|
||||
typename VType,
|
||||
typename OType,
|
||||
typename AccType,
|
||||
typename KVScaleType,
|
||||
naive_attention_layout_enum QLayout,
|
||||
naive_attention_layout_enum KLayout,
|
||||
naive_attention_layout_enum VLayout,
|
||||
naive_attention_layout_enum OLayout,
|
||||
naive_attention_layout_enum KScaleLayout,
|
||||
naive_attention_layout_enum VScaleLayout,
|
||||
typename Traits>
|
||||
struct naive_attention_fwd_kernel
|
||||
{
|
||||
static constexpr bool is_kvcache_i8 =
|
||||
std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
|
||||
static constexpr bool is_kvcache_fp8 =
|
||||
std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
|
||||
|
||||
static constexpr int v_per_token_quant_group_size = 64;
|
||||
static constexpr int kBlockSize = 256;
|
||||
// TODO: hardcode
|
||||
using SoftmaxType = float; // always using float to do softmax compute
|
||||
using QuantComputeType = float; // used for quant/dequant scale compute
|
||||
using QCompute = KType; // src A of gemm1, same type as K
|
||||
using PType = VType; // src A of gemm2, same type as V
|
||||
using OAccType = float; // always float, in case int8 FA
|
||||
|
||||
using p_vec_type = ext_vector_t<PType, 16 / sizeof(PType)>;
|
||||
static constexpr int p_vec_elem = vector_traits<p_vec_type>::vector_size;
|
||||
|
||||
// clang-format off
|
||||
template <typename T_> struct scale_max { static constexpr float value = 1; /* dummy code */ };
|
||||
template <> struct scale_max<int8_t> { static constexpr float value = 127.0; };
|
||||
template <> struct scale_max<fp8_t> { static constexpr float value = 240.0; };
|
||||
// clang-format on
|
||||
|
||||
__host__ __device__ naive_attention_fwd_kernel() {}
|
||||
|
||||
template <typename T, naive_attention_layout_enum Layout>
|
||||
struct addresser
|
||||
{
|
||||
int b, s, h, d; // batch, seqlen, nhead, hdim
|
||||
T* base_ptr;
|
||||
__device__ addresser(int b_, int s_, int h_, int d_, void* base_ptr_)
|
||||
: b(b_), s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(base_ptr_))
|
||||
{
|
||||
}
|
||||
|
||||
// TODO: all the batch/nhead offset will accumulate to the base pointer
|
||||
__device__ T* get_base(int i_b, int i_h)
|
||||
{
|
||||
if constexpr(Layout == naive_attention_layout_enum::BSHD)
|
||||
return base_ptr + i_b * s * h * d + i_h * d;
|
||||
else if constexpr(Layout == naive_attention_layout_enum::BHSD)
|
||||
return base_ptr + i_b * s * h * d + i_h * s * d;
|
||||
}
|
||||
|
||||
__device__ int get_offset(int i_s, int i_d)
|
||||
{
|
||||
if constexpr(Layout == naive_attention_layout_enum::BSHD)
|
||||
return i_s * h * d + i_d;
|
||||
else if constexpr(Layout == naive_attention_layout_enum::BHSD)
|
||||
return i_s * d + i_d;
|
||||
}
|
||||
|
||||
// below set of API will directly use pointer inside this struct
|
||||
__device__ void init(int i_b, int i_h) { base_ptr = get_base(i_b, i_h); }
|
||||
__device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
|
||||
__device__ void store(T value, int i_s, int i_d) { base_ptr[get_offset(i_s, i_d)] = value; }
|
||||
};
|
||||
|
||||
template <typename T, naive_attention_layout_enum Layout>
|
||||
struct page_addresser
|
||||
{
|
||||
int s, h, d; // page_size, nhead, hdim
|
||||
static constexpr int x = 16 / sizeof(T); // pack 4 dword
|
||||
T* base_ptr;
|
||||
int* page_table_ptr; // TODO: page table always int
|
||||
int i_h; // store current head
|
||||
|
||||
__device__ page_addresser(int s_, int h_, int d_, void* base_ptr_, void* pptr_)
|
||||
: s(s_),
|
||||
h(h_),
|
||||
d(d_),
|
||||
base_ptr(reinterpret_cast<T*>(base_ptr_)),
|
||||
page_table_ptr(reinterpret_cast<int*>(pptr_))
|
||||
{
|
||||
}
|
||||
|
||||
__device__ int64_t get_phy_page_idx(int i_s)
|
||||
{
|
||||
// dynamic compute page idx is simple but slow
|
||||
int page_idx = i_s / s;
|
||||
int phy = page_table_ptr[page_idx];
|
||||
return static_cast<int64_t>(phy);
|
||||
}
|
||||
|
||||
__device__ int get_phy_page_offset(int i_s)
|
||||
{
|
||||
// dynamic compute page idx is simple but slow
|
||||
return i_s % s;
|
||||
}
|
||||
|
||||
__device__ int64_t get_offset(int i_s, int i_d)
|
||||
{
|
||||
int page_offset = get_phy_page_offset(i_s);
|
||||
int64_t page_idx = get_phy_page_idx(i_s);
|
||||
int64_t base_ = page_idx * h * s * d;
|
||||
if constexpr(Layout == naive_attention_layout_enum::PHSD)
|
||||
return static_cast<int64_t>(i_h * s * d + page_offset * d + i_d) + base_;
|
||||
else if constexpr(Layout == naive_attention_layout_enum::PHDSX)
|
||||
{
|
||||
int d_r = i_d / x;
|
||||
int d_x = i_d % x;
|
||||
return static_cast<int64_t>(i_h * d * s + d_r * s * x + page_offset * x + d_x) +
|
||||
base_;
|
||||
}
|
||||
else if constexpr(Layout == naive_attention_layout_enum::PHDS)
|
||||
{
|
||||
return static_cast<int64_t>(i_h * d * s + i_d * s + page_offset) + base_;
|
||||
}
|
||||
}
|
||||
|
||||
// below set of API will directly use pointer inside this struct
|
||||
__device__ void init(int /*i_b*/, int i_h_) { i_h = i_h_; }
|
||||
__device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
|
||||
__device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
|
||||
};
|
||||
|
||||
template <typename T, naive_attention_layout_enum Layout>
|
||||
struct kvscale_addresser
|
||||
{
|
||||
int s, h, d; // seqlen(tokens), nhead, hdim
|
||||
T* base_ptr;
|
||||
__device__ kvscale_addresser(int s_, int h_, int d_, void* p_)
|
||||
: s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
|
||||
{
|
||||
}
|
||||
__device__ int get_offset(int i_s, int i_h, int i_d)
|
||||
{
|
||||
if constexpr(Layout == naive_attention_layout_enum::SCALE_HS)
|
||||
{
|
||||
// [nhead, tokens]
|
||||
(void)i_d;
|
||||
return i_h * s + i_s;
|
||||
}
|
||||
else if constexpr(Layout == naive_attention_layout_enum::DEFAULT)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
// [h, 2, d]
|
||||
// return i_h * 2 * d + i_kv * d + i_d;
|
||||
}
|
||||
__device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
|
||||
};
|
||||
|
||||
__device__ __host__ static constexpr int get_block_size() { return kBlockSize; }
|
||||
|
||||
// for simpliciy, 1 WG always compute 1 token along q, compute all token along kv
|
||||
// compute all hdim from q, compute WG_SIZE hdim from v
|
||||
// 1) in prefill case, seqlen_q >= 1, seqlen_kv >= 1, batch_q=batch_kv
|
||||
// 2) in decode case, seqlen_q = 1, batch_q is input num-tokens, batch_kv is 1
|
||||
// 3) in paged-attn case, we still use 1 WG compute all the seqlen-kv for simplicity
|
||||
// TODO: could support split-kv to validate intermediate logsum
|
||||
__host__ static dim3 get_grid_size(naive_attention_fwd_args args)
|
||||
{
|
||||
constexpr int wg_size = get_block_size();
|
||||
auto g =
|
||||
dim3((args.hdim_v + wg_size - 1) / wg_size, args.seqlen_q, args.batch_q * args.nhead_q);
|
||||
return g;
|
||||
}
|
||||
|
||||
// reduce single pixel within a wave
|
||||
template <typename T, typename F>
|
||||
__device__ constexpr T wave_reduce(T local, F reduce_f)
|
||||
{
|
||||
// constexpr int wave_size = 64;
|
||||
constexpr int reduce_stage = 6; // 1<<6=64
|
||||
T v_local = local;
|
||||
#pragma unroll
|
||||
for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
|
||||
{
|
||||
int src_lane = __lane_id() ^ (1 << i_stage);
|
||||
int32_t v_remote_tmp =
|
||||
__builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
|
||||
T v_remote = bit_cast<T>(v_remote_tmp);
|
||||
v_local = reduce_f(v_local, v_remote);
|
||||
}
|
||||
return v_local;
|
||||
}
|
||||
|
||||
// Note: this function must be called after wave_reduce
|
||||
// Note: better not use this under if...else... with thread divergence (syncthreads)
|
||||
template <typename T, typename F>
|
||||
__device__ constexpr T cross_wave_reduce(T local, F reduce_f, T* smem)
|
||||
{
|
||||
constexpr int waves = 4;
|
||||
constexpr int wave_size = 64;
|
||||
int lane_id = threadIdx.x % wave_size;
|
||||
|
||||
__syncthreads();
|
||||
smem[threadIdx.x] = local;
|
||||
__syncthreads();
|
||||
|
||||
// the data within single wave is the same
|
||||
// but for simplicity, we still use data from each lane.
|
||||
T v_local = smem[lane_id];
|
||||
#pragma unroll
|
||||
for(int i_stage = 1; i_stage < waves; i_stage++)
|
||||
{
|
||||
T v_remote = smem[i_stage * wave_size + lane_id];
|
||||
v_local = reduce_f(v_local, v_remote);
|
||||
}
|
||||
return v_local;
|
||||
}
|
||||
|
||||
// kernel entry point
|
||||
__device__ void operator()(naive_attention_fwd_args args)
|
||||
{
|
||||
constexpr int wg_size = get_block_size();
|
||||
__shared__ char smem[wg_size * 4 * sizeof(float)]; // should enough
|
||||
char* smem_quant_q = smem + wg_size * 2 * sizeof(float); // second half, should enough
|
||||
int i_dv = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v
|
||||
int i_sq = blockIdx.y; // index of seqlen_q
|
||||
int i_batch = blockIdx.z; // index of batch_q * nhead_q
|
||||
int i_bq = i_batch / args.nhead_q; // index of batch_q
|
||||
int i_hq = i_batch % args.nhead_q; // index of nhead_q
|
||||
|
||||
int i_bk = i_bq / args.batch_ratio_kv;
|
||||
int i_hk = i_hq / args.nhead_ratio_kv;
|
||||
|
||||
void* page_table_ptr = [&]() {
|
||||
if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
|
||||
{
|
||||
return reinterpret_cast<int*>(args.page_table_ptr) + i_bq * args.max_pages_per_seq;
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}();
|
||||
|
||||
auto q_addr = [&]() {
|
||||
if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
|
||||
{
|
||||
return addresser<QType, QLayout>{
|
||||
args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
|
||||
}
|
||||
else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
|
||||
{
|
||||
return addresser<QType, QLayout>{
|
||||
args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
|
||||
}
|
||||
}();
|
||||
auto k_addr = [&]() {
|
||||
if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
|
||||
{
|
||||
return addresser<KType, KLayout>{
|
||||
args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim, args.k_ptr};
|
||||
}
|
||||
else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
|
||||
{
|
||||
return page_addresser<KType, KLayout>{
|
||||
args.page_size, args.nhead_kv, args.hdim, args.k_ptr, page_table_ptr};
|
||||
}
|
||||
}();
|
||||
auto v_addr = [&]() {
|
||||
if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
|
||||
{
|
||||
return addresser<VType, VLayout>{
|
||||
args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim_v, args.v_ptr};
|
||||
}
|
||||
else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
|
||||
{
|
||||
return page_addresser<VType, VLayout>{
|
||||
args.page_size, args.nhead_kv, args.hdim_v, args.v_ptr, page_table_ptr};
|
||||
}
|
||||
}();
|
||||
auto o_addr = [&]() {
|
||||
if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
|
||||
{
|
||||
return addresser<OType, OLayout>{
|
||||
args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
|
||||
}
|
||||
else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
|
||||
{
|
||||
return addresser<OType, OLayout>{
|
||||
args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
|
||||
}
|
||||
}();
|
||||
|
||||
q_addr.init(i_bq, i_hq);
|
||||
k_addr.init(i_bk, i_hk);
|
||||
v_addr.init(i_bk, i_hk);
|
||||
o_addr.init(i_bq, i_hq);
|
||||
|
||||
auto f_max = [](auto x_, auto y_) { return max(x_, y_); };
|
||||
auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
|
||||
auto f_absmax_f32 = [](float v_0_, float v_1_) {
|
||||
// float rtn;
|
||||
// asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
|
||||
// return rtn;
|
||||
return max(abs(v_0_), abs(v_1_));
|
||||
};
|
||||
|
||||
int seqlen_kv = [&]() {
|
||||
if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
|
||||
{
|
||||
return args.seqlen_kv;
|
||||
}
|
||||
else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
|
||||
{
|
||||
return reinterpret_cast<int*>(args.context_len_ptr)[i_bq];
|
||||
}
|
||||
}();
|
||||
|
||||
SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
|
||||
SoftmaxType l{0};
|
||||
// AccType o_acc = {0};
|
||||
OAccType o_acc = {0};
|
||||
|
||||
int sk_loops = (seqlen_kv + wg_size - 1) / wg_size;
|
||||
QuantComputeType q_dequant_scale = .0f;
|
||||
kvscale_addresser<KVScaleType, KScaleLayout> kscale_addr{
|
||||
args.max_kv_tokens, args.nhead_kv, args.hdim, args.kscale_ptr};
|
||||
kvscale_addresser<KVScaleType, VScaleLayout> vscale_addr{
|
||||
args.max_kv_tokens, args.nhead_kv, args.hdim_v, args.vscale_ptr};
|
||||
|
||||
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
|
||||
{
|
||||
// AccType is i32 now, seqlen_q = 1, hdim up to 256
|
||||
AccType q = 0;
|
||||
AccType k_s = 0;
|
||||
if(static_cast<int>(threadIdx.x) < args.hdim)
|
||||
{
|
||||
q = type_convert<AccType>(q_addr.load(0, threadIdx.x));
|
||||
k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
|
||||
}
|
||||
// 1) we apply the k scale to q
|
||||
AccType q_forwarded = q * k_s;
|
||||
|
||||
// 2) apply smooth-quant
|
||||
// find absmax
|
||||
AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
|
||||
qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));
|
||||
|
||||
// per-token scale
|
||||
q_dequant_scale = type_convert<QuantComputeType>(qf_max) / scale_max<QCompute>::value;
|
||||
|
||||
// devide by scale
|
||||
q = q / q_dequant_scale;
|
||||
|
||||
// fp32->i8
|
||||
QCompute quantized_q = static_cast<QCompute>(q);
|
||||
__syncthreads();
|
||||
reinterpret_cast<QCompute*>(smem)[threadIdx.x] = quantized_q;
|
||||
__syncthreads();
|
||||
|
||||
// after above process, we have 2 data
|
||||
// 1) int8 q data stored in smem(no need to reload)
|
||||
// 2) per-token scale q_dequant_scale, to be mul after 1st gemm
|
||||
}
|
||||
else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
|
||||
{
|
||||
if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
|
||||
{
|
||||
// dyanmic quant q here
|
||||
float q = 0;
|
||||
if(static_cast<int>(threadIdx.x) < args.hdim)
|
||||
{
|
||||
q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
|
||||
}
|
||||
|
||||
// apply smooth-quant
|
||||
// find absmax
|
||||
float q_max = wave_reduce(q, f_absmax_f32);
|
||||
q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));
|
||||
|
||||
// per-token scale
|
||||
q_dequant_scale =
|
||||
type_convert<QuantComputeType>(q_max) / scale_max<QCompute>::value;
|
||||
|
||||
// devide by scale
|
||||
q = q / q_dequant_scale;
|
||||
|
||||
QCompute quantized_q = type_convert<QCompute>(q);
|
||||
__syncthreads();
|
||||
reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
|
||||
__syncthreads();
|
||||
|
||||
// after above process, we have 2 data
|
||||
// 1) fp8 q data stored in smem(no need to reload from global)
|
||||
// 2) per-token scale q_dequant_scale, to be mul after 1st gemm
|
||||
}
|
||||
}
|
||||
|
||||
for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
|
||||
{
|
||||
int i_sk = i_loop1 * wg_size + threadIdx.x;
|
||||
// gemm-1
|
||||
SoftmaxType s_softmax = -numeric<SoftmaxType>::infinity();
|
||||
if(i_sk < seqlen_kv)
|
||||
{
|
||||
AccType s_acc{0}; // clear for every loop
|
||||
for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
|
||||
{
|
||||
auto q = [&]() {
|
||||
if constexpr(Traits::quant_algo ==
|
||||
naive_attention_quant_algo::KV_8BIT_PERHEAD ||
|
||||
Traits::quant_algo ==
|
||||
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
|
||||
{
|
||||
return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq];
|
||||
}
|
||||
else
|
||||
return q_addr.load(i_sq, i_dq); // q will have duplicate load
|
||||
}();
|
||||
auto k = [&]() { return k_addr.load(i_sk, i_dq); }();
|
||||
|
||||
s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
|
||||
}
|
||||
// scale
|
||||
s_softmax = type_convert<SoftmaxType>(s_acc);
|
||||
s_softmax *=
|
||||
type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
|
||||
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
|
||||
{
|
||||
s_softmax *= q_dequant_scale; // post scale the per-token factor
|
||||
}
|
||||
else if constexpr(Traits::quant_algo ==
|
||||
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
|
||||
{
|
||||
SoftmaxType k_per_token_scale =
|
||||
type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
|
||||
s_softmax *= q_dequant_scale;
|
||||
s_softmax *= k_per_token_scale;
|
||||
}
|
||||
}
|
||||
|
||||
// s->p
|
||||
QuantComputeType p_dequant_scale = 1.;
|
||||
{
|
||||
// softmax, find max
|
||||
SoftmaxType old_max = row_max;
|
||||
SoftmaxType cur_max = wave_reduce(s_softmax, f_max);
|
||||
|
||||
cur_max = cross_wave_reduce(cur_max, f_max, reinterpret_cast<SoftmaxType*>(smem));
|
||||
row_max = max(old_max, cur_max); // update row_max
|
||||
// softmax, exp(i_elem - max)
|
||||
SoftmaxType p_compute = __builtin_amdgcn_exp2f(s_softmax - row_max);
|
||||
|
||||
// compute exp_sum
|
||||
SoftmaxType row_sum = wave_reduce(p_compute, f_sum);
|
||||
row_sum = cross_wave_reduce(row_sum, f_sum, reinterpret_cast<SoftmaxType*>(smem));
|
||||
|
||||
// l, pre-scall o_acc
|
||||
SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
|
||||
l = tmp * l + row_sum;
|
||||
o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
|
||||
|
||||
// prepare the p_compute into smem, to let every thread read same p_compute and do
|
||||
// 2nd gemm
|
||||
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
|
||||
{
|
||||
QuantComputeType v_s = 0;
|
||||
if(static_cast<int>(threadIdx.x) < args.hdim_v)
|
||||
{
|
||||
v_s =
|
||||
type_convert<QuantComputeType>(vscale_addr.load(i_hk, threadIdx.x, 1));
|
||||
}
|
||||
|
||||
// 1) we apply the v scale to p
|
||||
QuantComputeType p_forwarded = p_compute * v_s;
|
||||
|
||||
// 2) apply smooth-quant
|
||||
// find absmax
|
||||
QuantComputeType pf_max = wave_reduce(p_forwarded, f_absmax_f32);
|
||||
pf_max = cross_wave_reduce(
|
||||
pf_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
|
||||
|
||||
// per-token scale
|
||||
p_dequant_scale = pf_max / scale_max<PType>::value; // 127.0;
|
||||
|
||||
// devide by scale
|
||||
p_compute = p_compute / p_dequant_scale;
|
||||
|
||||
// fp32->i8
|
||||
PType quantized_p = static_cast<PType>(p_compute);
|
||||
__syncthreads();
|
||||
reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
|
||||
__syncthreads();
|
||||
// after above process, we have 2 data
|
||||
// 1) int8 p data stored in smem(no need to reload)
|
||||
// 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
|
||||
}
|
||||
else if constexpr(Traits::quant_algo ==
|
||||
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
|
||||
{
|
||||
// forward apply the v scale to p_compute, this is compute friendly
|
||||
auto v_scale = type_convert<QuantComputeType>(vscale_addr.load(i_sk, i_hk, 0));
|
||||
p_compute *= v_scale;
|
||||
// smooth-quant
|
||||
// find absmax
|
||||
QuantComputeType p_max = wave_reduce(p_compute, f_absmax_f32);
|
||||
p_max = cross_wave_reduce(
|
||||
p_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
|
||||
|
||||
// per-token scale
|
||||
p_dequant_scale = p_max / scale_max<PType>::value; // 240.0;
|
||||
|
||||
// devide by scale
|
||||
p_compute = p_compute / p_dequant_scale;
|
||||
|
||||
// fp32->i8
|
||||
PType quantized_p = type_convert<PType>(p_compute);
|
||||
__syncthreads();
|
||||
reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
|
||||
__syncthreads();
|
||||
// after above process, we have 2 data
|
||||
// 1) fp8_t p data stored in smem(no need to reload)
|
||||
// 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
|
||||
}
|
||||
else
|
||||
{
|
||||
__syncthreads();
|
||||
reinterpret_cast<PType*>(smem)[threadIdx.x] = type_convert<PType>(p_compute);
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// gemm-2, simple loop over vector by vector
|
||||
constexpr int gemm_2_loop = wg_size / p_vec_elem;
|
||||
{
|
||||
AccType o_acc_local = {0};
|
||||
int sk_start = i_loop1 * wg_size; // we start from the first seqlen_kv element
|
||||
for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
|
||||
{
|
||||
p_vec_type p_vec = reinterpret_cast<p_vec_type*>(smem)[i_loop2];
|
||||
#pragma unroll
|
||||
for(int i_j = 0; i_j < p_vec_elem; i_j++)
|
||||
{
|
||||
int sv_offset = i_loop2 * p_vec_elem + i_j;
|
||||
int i_sv = sk_start + sv_offset;
|
||||
|
||||
VType v = 0;
|
||||
if(i_dv < args.hdim_v && i_sv < seqlen_kv)
|
||||
{
|
||||
v = v_addr.load(i_sv, i_dv);
|
||||
}
|
||||
|
||||
AccType v_compute = [&]() { return type_convert<AccType>(v); }();
|
||||
|
||||
o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
|
||||
}
|
||||
}
|
||||
|
||||
OAccType post_scale_o_acc_local = [&]() {
|
||||
if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
|
||||
{
|
||||
// apply pr scale to local acc
|
||||
return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
|
||||
p_dequant_scale);
|
||||
}
|
||||
else if constexpr(Traits::quant_algo ==
|
||||
naive_attention_quant_algo::KV_8BIT_PERTOKEN)
|
||||
{
|
||||
// apply pr scale to local acc
|
||||
return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
|
||||
p_dequant_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
return type_convert<OAccType>(o_acc_local);
|
||||
}
|
||||
}();
|
||||
o_acc += post_scale_o_acc_local;
|
||||
}
|
||||
}
|
||||
|
||||
// post scale o_acc
|
||||
{
|
||||
SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking
|
||||
o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
|
||||
}
|
||||
|
||||
// store O
|
||||
if(i_dv < args.hdim_v)
|
||||
o_addr.store(type_convert<OType>(o_acc), i_sq, i_dv);
|
||||
}
|
||||
};
|
||||
|
||||
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_() \
|
||||
{ \
|
||||
using ktraits_ = naive_attention_fwd_kernel_traits< \
|
||||
static_cast<naive_attention_variation_enum>(variation_), \
|
||||
static_cast<naive_attention_quant_algo>(quant_algo_)>; \
|
||||
using k_ = naive_attention_fwd_kernel<q_type_, \
|
||||
k_type_, \
|
||||
v_type_, \
|
||||
o_type_, \
|
||||
acc_type_, \
|
||||
kvscale_type_, \
|
||||
q_layout_, \
|
||||
k_layout_, \
|
||||
v_layout_, \
|
||||
o_layout_, \
|
||||
k_scale_layout_, \
|
||||
v_scale_layout_, \
|
||||
ktraits_>; \
|
||||
dim3 grids = k_::get_grid_size(a); \
|
||||
r = ck_tile::launch_kernel(s, \
|
||||
ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a)); \
|
||||
}
|
||||
|
||||
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_() \
|
||||
if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
|
||||
t.o_layout == "bshd") \
|
||||
{ \
|
||||
constexpr auto q_layout_ = naive_attention_layout_enum::BSHD; \
|
||||
constexpr auto k_layout_ = naive_attention_layout_enum::BSHD; \
|
||||
constexpr auto v_layout_ = naive_attention_layout_enum::BSHD; \
|
||||
constexpr auto o_layout_ = naive_attention_layout_enum::BSHD; \
|
||||
constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
|
||||
constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
|
||||
constexpr int variation_ = 0; \
|
||||
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
|
||||
} \
|
||||
else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" && \
|
||||
t.v_layout == "bhsd" && t.o_layout == "bhsd") \
|
||||
{ \
|
||||
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
|
||||
constexpr auto k_layout_ = naive_attention_layout_enum::BHSD; \
|
||||
constexpr auto v_layout_ = naive_attention_layout_enum::BHSD; \
|
||||
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
|
||||
constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
|
||||
constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
|
||||
constexpr int variation_ = 0; \
|
||||
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
|
||||
} \
|
||||
else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" && \
|
||||
t.v_layout == "phds" && t.o_layout == "bhsd") \
|
||||
{ \
|
||||
constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
|
||||
constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX; \
|
||||
constexpr auto v_layout_ = naive_attention_layout_enum::PHDS; \
|
||||
constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
|
||||
constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
|
||||
constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
|
||||
constexpr int variation_ = 2; \
|
||||
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
|
||||
}
|
||||
|
||||
//
|
||||
CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
|
||||
naive_attention_fwd_args a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
float r = -1;
|
||||
// TODO: do not explicitly create too much instance!
|
||||
if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16" &&
|
||||
t.quant_algo == 0)
|
||||
{
|
||||
using q_type_ = fp16_t;
|
||||
using k_type_ = fp16_t;
|
||||
using v_type_ = fp16_t;
|
||||
using o_type_ = fp16_t;
|
||||
using acc_type_ = float;
|
||||
using kvscale_type_ = float;
|
||||
constexpr int quant_algo_ = 0;
|
||||
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
|
||||
}
|
||||
else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16" &&
|
||||
t.quant_algo == 0)
|
||||
{
|
||||
using q_type_ = bf16_t;
|
||||
using k_type_ = bf16_t;
|
||||
using v_type_ = bf16_t;
|
||||
using o_type_ = bf16_t;
|
||||
using acc_type_ = float;
|
||||
using kvscale_type_ = float;
|
||||
constexpr int quant_algo_ = 0;
|
||||
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
|
||||
}
|
||||
else if(t.q_type == "bf16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "bf16" &&
|
||||
t.quant_algo == 2)
|
||||
{
|
||||
using q_type_ = bf16_t;
|
||||
using k_type_ = fp8_t;
|
||||
using v_type_ = fp8_t;
|
||||
using o_type_ = bf16_t;
|
||||
using acc_type_ = float; // NOTE!
|
||||
using kvscale_type_ = float;
|
||||
constexpr int quant_algo_ = 2;
|
||||
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
|
||||
}
|
||||
else if(t.q_type == "fp16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "fp16" &&
|
||||
t.quant_algo == 2)
|
||||
{
|
||||
using q_type_ = fp16_t;
|
||||
using k_type_ = fp8_t;
|
||||
using v_type_ = fp8_t;
|
||||
using o_type_ = fp16_t;
|
||||
using acc_type_ = float; // NOTE!
|
||||
using kvscale_type_ = float;
|
||||
constexpr int quant_algo_ = 2;
|
||||
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
|
||||
}
|
||||
else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16" &&
|
||||
t.quant_algo == 2)
|
||||
{
|
||||
using q_type_ = bf16_t;
|
||||
using k_type_ = int8_t;
|
||||
using v_type_ = int8_t;
|
||||
using o_type_ = bf16_t;
|
||||
using acc_type_ = int32_t; // NOTE!
|
||||
using kvscale_type_ = float;
|
||||
constexpr int quant_algo_ = 2;
|
||||
CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_
|
||||
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_
|
||||
|
||||
} // namespace ck_tile
|
||||
360
include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp
Normal file
360
include/ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp
Normal file
@@ -0,0 +1,360 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ref/conv_common.hpp"
|
||||
#include <array>
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Naive GPU reference kernel struct for backward data grouped convolution
|
||||
// Computes gradient with respect to input
|
||||
// Layout: Input_grad=NDHWGC, Weight=GKZYXC, Output_grad=NDHWGK (for 3D case)
|
||||
// Input_grad=NHWGC, Weight=GKYXC, Output_grad=NHWGK (for 2D case)
|
||||
// Input_grad=NWGC, Weight=GKXC, Output_grad=NWGK (for 1D case)
|
||||
//
|
||||
// One thread per input element, uses grid-stride loop pattern
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
struct naive_grouped_conv_bwd_data_kernel
|
||||
{
|
||||
static constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
__device__ void
|
||||
operator()(InDataType* __restrict__ p_in_grad,
|
||||
const WeiDataType* __restrict__ p_wei,
|
||||
const OutDataType* __restrict__ p_out_grad,
|
||||
// Tensor dimensions
|
||||
ck_tile::index_t G, // number of groups
|
||||
ck_tile::index_t N, // batch size
|
||||
ck_tile::index_t K, // output channels per group
|
||||
ck_tile::index_t C, // input channels per group
|
||||
// Input spatial dimensions
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& in_spatial_lengths,
|
||||
// Weight spatial dimensions
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& wei_spatial_lengths,
|
||||
// Output spatial dimensions
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& out_spatial_lengths,
|
||||
// Convolution parameters
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_strides,
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_dilations,
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& in_left_pads) const
|
||||
{
|
||||
const ck_tile::long_index_t tid = get_block_id() * blockDim.x + get_thread_id();
|
||||
const ck_tile::long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
|
||||
// Calculate total input elements
|
||||
ck_tile::long_index_t input_length = G * N * C;
|
||||
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
input_length *= in_spatial_lengths[i];
|
||||
}
|
||||
|
||||
// Calculate strides for input tensor (NDHWGC or NHWGC or NWGC)
|
||||
std::array<ck_tile::long_index_t, NDimSpatial + 3> in_strides;
|
||||
ck_tile::long_index_t stride = 1;
|
||||
in_strides[NDimSpatial + 2] = stride; // C stride
|
||||
stride *= C;
|
||||
in_strides[NDimSpatial + 1] = stride; // G stride
|
||||
stride *= G;
|
||||
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
|
||||
{
|
||||
in_strides[i + 1] = stride;
|
||||
stride *= in_spatial_lengths[i];
|
||||
}
|
||||
in_strides[0] = stride; // N stride
|
||||
|
||||
// Calculate strides for output tensor (NDHWGK or NHWGK or NWGK)
|
||||
std::array<ck_tile::long_index_t, NDimSpatial + 3> out_strides;
|
||||
stride = 1;
|
||||
out_strides[NDimSpatial + 2] = stride; // K stride
|
||||
stride *= K;
|
||||
out_strides[NDimSpatial + 1] = stride; // G stride
|
||||
stride *= G;
|
||||
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
|
||||
{
|
||||
out_strides[i + 1] = stride;
|
||||
stride *= out_spatial_lengths[i];
|
||||
}
|
||||
out_strides[0] = stride; // N stride
|
||||
|
||||
// Calculate strides for weight tensor (GKZYXC or GKYXC or GKXC)
|
||||
std::array<ck_tile::long_index_t, NDimSpatial + 3> wei_strides;
|
||||
stride = 1;
|
||||
wei_strides[NDimSpatial + 2] = stride; // C stride
|
||||
stride *= C;
|
||||
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
|
||||
{
|
||||
wei_strides[i + 2] = stride;
|
||||
stride *= wei_spatial_lengths[i];
|
||||
}
|
||||
wei_strides[1] = stride; // K stride
|
||||
stride *= K;
|
||||
wei_strides[0] = stride; // G stride
|
||||
|
||||
// Grid-stride loop over all input elements
|
||||
for(ck_tile::long_index_t ii = tid; ii < input_length; ii += num_threads)
|
||||
{
|
||||
// Decode linear index to multi-dimensional indices
|
||||
ck_tile::long_index_t tmp = ii;
|
||||
|
||||
// Extract N (batch)
|
||||
ck_tile::index_t n = tmp / in_strides[0];
|
||||
tmp -= n * in_strides[0];
|
||||
|
||||
// Extract spatial dimensions
|
||||
ck_tile::index_t in_spatial_idx[6];
|
||||
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
in_spatial_idx[i] = tmp / in_strides[i + 1];
|
||||
tmp -= in_spatial_idx[i] * in_strides[i + 1];
|
||||
}
|
||||
|
||||
// Extract G (group)
|
||||
ck_tile::index_t g = tmp / in_strides[NDimSpatial + 1];
|
||||
tmp -= g * in_strides[NDimSpatial + 1];
|
||||
|
||||
// Extract C (input channel)
|
||||
ck_tile::index_t c = tmp;
|
||||
|
||||
// Accumulate in float
|
||||
float v_acc = 0.0f;
|
||||
|
||||
// Loop over output channels
|
||||
for(ck_tile::index_t k = 0; k < K; ++k)
|
||||
{
|
||||
// Loop over filter spatial dimensions
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[0]; ++x)
|
||||
{
|
||||
// Calculate output spatial coordinate (inverse of forward)
|
||||
ck_tile::long_index_t w_tmp =
|
||||
static_cast<ck_tile::long_index_t>(in_spatial_idx[0]) +
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]);
|
||||
|
||||
// Check if this maps to valid output position
|
||||
if(w_tmp % conv_strides[0] == 0)
|
||||
{
|
||||
ck_tile::long_index_t wo = w_tmp / conv_strides[0];
|
||||
|
||||
if(wo >= 0 && wo < out_spatial_lengths[0])
|
||||
{
|
||||
std::array<ck_tile::index_t, 1> out_spatial = {
|
||||
static_cast<index_t>(wo)};
|
||||
std::array<ck_tile::index_t, 1> wei_spatial = {x};
|
||||
ck_tile::long_index_t out_idx = detail::calculate_output_index<1>(
|
||||
n, g, k, out_spatial, out_strides);
|
||||
ck_tile::long_index_t wei_idx = detail::calculate_weight_index<1>(
|
||||
g, k, c, wei_spatial, wei_strides);
|
||||
|
||||
v_acc += type_convert<float>(p_out_grad[out_idx]) *
|
||||
type_convert<float>(p_wei[wei_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[0]; ++y)
|
||||
{
|
||||
ck_tile::long_index_t h_tmp =
|
||||
static_cast<ck_tile::long_index_t>(in_spatial_idx[0]) +
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]);
|
||||
|
||||
if(h_tmp % conv_strides[0] == 0)
|
||||
{
|
||||
ck_tile::long_index_t ho = h_tmp / conv_strides[0];
|
||||
|
||||
if(ho >= 0 && ho < out_spatial_lengths[0])
|
||||
{
|
||||
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[1]; ++x)
|
||||
{
|
||||
ck_tile::long_index_t w_tmp =
|
||||
static_cast<ck_tile::long_index_t>(in_spatial_idx[1]) +
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]);
|
||||
|
||||
if(w_tmp % conv_strides[1] == 0)
|
||||
{
|
||||
ck_tile::long_index_t wo = w_tmp / conv_strides[1];
|
||||
|
||||
if(wo >= 0 && wo < out_spatial_lengths[1])
|
||||
{
|
||||
std::array<ck_tile::index_t, 2> out_spatial = {
|
||||
static_cast<index_t>(ho), static_cast<index_t>(wo)};
|
||||
std::array<ck_tile::index_t, 2> wei_spatial = {y, x};
|
||||
ck_tile::long_index_t out_idx =
|
||||
detail::calculate_output_index<2>(
|
||||
n, g, k, out_spatial, out_strides);
|
||||
ck_tile::long_index_t wei_idx =
|
||||
detail::calculate_weight_index<2>(
|
||||
g, k, c, wei_spatial, wei_strides);
|
||||
|
||||
v_acc += type_convert<float>(p_out_grad[out_idx]) *
|
||||
type_convert<float>(p_wei[wei_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
for(ck_tile::index_t z = 0; z < wei_spatial_lengths[0]; ++z)
|
||||
{
|
||||
ck_tile::long_index_t d_tmp =
|
||||
static_cast<ck_tile::long_index_t>(in_spatial_idx[0]) +
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]) -
|
||||
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]);
|
||||
|
||||
if(d_tmp % conv_strides[0] == 0)
|
||||
{
|
||||
ck_tile::long_index_t do_ = d_tmp / conv_strides[0];
|
||||
|
||||
if(do_ >= 0 && do_ < out_spatial_lengths[0])
|
||||
{
|
||||
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[1]; ++y)
|
||||
{
|
||||
ck_tile::long_index_t h_tmp =
|
||||
static_cast<ck_tile::long_index_t>(in_spatial_idx[1]) +
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]) -
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]);
|
||||
|
||||
if(h_tmp % conv_strides[1] == 0)
|
||||
{
|
||||
ck_tile::long_index_t ho = h_tmp / conv_strides[1];
|
||||
|
||||
if(ho >= 0 && ho < out_spatial_lengths[1])
|
||||
{
|
||||
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[2];
|
||||
++x)
|
||||
{
|
||||
ck_tile::long_index_t w_tmp =
|
||||
static_cast<ck_tile::long_index_t>(
|
||||
in_spatial_idx[2]) +
|
||||
static_cast<ck_tile::long_index_t>(
|
||||
in_left_pads[2]) -
|
||||
static_cast<ck_tile::long_index_t>(
|
||||
x * conv_dilations[2]);
|
||||
|
||||
if(w_tmp % conv_strides[2] == 0)
|
||||
{
|
||||
ck_tile::long_index_t wo =
|
||||
w_tmp / conv_strides[2];
|
||||
|
||||
if(wo >= 0 && wo < out_spatial_lengths[2])
|
||||
{
|
||||
std::array<ck_tile::index_t, 3>
|
||||
out_spatial = {
|
||||
static_cast<index_t>(do_),
|
||||
static_cast<index_t>(ho),
|
||||
static_cast<index_t>(wo)};
|
||||
std::array<ck_tile::index_t, 3>
|
||||
wei_spatial = {z, y, x};
|
||||
ck_tile::long_index_t out_idx =
|
||||
detail::calculate_output_index<3>(
|
||||
n, g, k, out_spatial, out_strides);
|
||||
ck_tile::long_index_t wei_idx =
|
||||
detail::calculate_weight_index<3>(
|
||||
g, k, c, wei_spatial, wei_strides);
|
||||
|
||||
v_acc +=
|
||||
type_convert<float>(
|
||||
p_out_grad[out_idx]) *
|
||||
type_convert<float>(p_wei[wei_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert accumulator to output type and write
|
||||
p_in_grad[ii] = type_convert<InDataType>(v_acc);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Host-side launcher for naive grouped convolution backward data
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
CK_TILE_HOST float
|
||||
naive_grouped_conv_bwd_data(InDataType* p_in_grad_dev,
|
||||
const WeiDataType* p_wei_dev,
|
||||
const OutDataType* p_out_grad_dev,
|
||||
ck_tile::index_t G,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t C,
|
||||
std::vector<ck_tile::long_index_t> in_spatial_lengths,
|
||||
std::vector<ck_tile::long_index_t> wei_spatial_lengths,
|
||||
std::vector<ck_tile::long_index_t> out_spatial_lengths,
|
||||
std::vector<ck_tile::long_index_t> conv_strides,
|
||||
std::vector<ck_tile::long_index_t> conv_dilations,
|
||||
std::vector<ck_tile::long_index_t> in_left_pads,
|
||||
ck_tile::stream_config stream_config = {})
|
||||
{
|
||||
// Convert vectors to arrays
|
||||
auto in_spatial_arr = to_array_with_default<NDimSpatial>(in_spatial_lengths);
|
||||
auto wei_spatial_arr = to_array_with_default<NDimSpatial>(wei_spatial_lengths);
|
||||
auto out_spatial_arr = to_array_with_default<NDimSpatial>(out_spatial_lengths);
|
||||
auto conv_strides_arr = to_array_with_default<NDimSpatial>(conv_strides);
|
||||
auto conv_dilations_arr = to_array_with_default<NDimSpatial>(conv_dilations);
|
||||
auto in_left_pads_arr = to_array_with_default<NDimSpatial>(in_left_pads, 0);
|
||||
|
||||
// Calculate grid size
|
||||
ck_tile::long_index_t input_length = G * N * C;
|
||||
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
input_length *= in_spatial_lengths[i];
|
||||
}
|
||||
|
||||
using KernelType =
|
||||
naive_grouped_conv_bwd_data_kernel<NDimSpatial, InDataType, WeiDataType, OutDataType>;
|
||||
|
||||
constexpr ck_tile::index_t block_size = KernelType::kBlockSize;
|
||||
const ck_tile::index_t grid_size = (input_length + block_size - 1) / block_size;
|
||||
|
||||
// Launch kernel
|
||||
float elapsed_ms = launch_kernel(stream_config,
|
||||
make_kernel(KernelType{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0, // dynamic shared memory size
|
||||
p_in_grad_dev,
|
||||
p_wei_dev,
|
||||
p_out_grad_dev,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
in_spatial_arr,
|
||||
wei_spatial_arr,
|
||||
out_spatial_arr,
|
||||
conv_strides_arr,
|
||||
conv_dilations_arr,
|
||||
in_left_pads_arr));
|
||||
|
||||
return elapsed_ms;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
324
include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp
Normal file
324
include/ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp
Normal file
@@ -0,0 +1,324 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ref/conv_common.hpp"
|
||||
#include <array>
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Naive GPU reference kernel struct for backward weight grouped convolution
|
||||
// Computes gradient with respect to weights
|
||||
// Layout: Input=NDHWGC, Weight_grad=GKZYXC, Output_grad=NDHWGK (for 3D case)
|
||||
// Input=NHWGC, Weight_grad=GKYXC, Output_grad=NHWGK (for 2D case)
|
||||
// Input=NWGC, Weight_grad=GKXC, Output_grad=NWGK (for 1D case)
|
||||
//
|
||||
// One thread per weight element, uses grid-stride loop pattern
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
struct naive_grouped_conv_bwd_weight_kernel
|
||||
{
|
||||
static constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
__device__ void
|
||||
operator()(const InDataType* __restrict__ p_in,
|
||||
WeiDataType* __restrict__ p_wei_grad,
|
||||
const OutDataType* __restrict__ p_out_grad,
|
||||
// Tensor dimensions
|
||||
ck_tile::index_t G, // number of groups
|
||||
ck_tile::index_t N, // batch size
|
||||
ck_tile::index_t K, // output channels per group
|
||||
ck_tile::index_t C, // input channels per group
|
||||
// Input spatial dimensions
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& in_spatial_lengths,
|
||||
// Weight spatial dimensions
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& wei_spatial_lengths,
|
||||
// Output spatial dimensions
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& out_spatial_lengths,
|
||||
// Convolution parameters
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_strides,
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_dilations,
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& in_left_pads) const
|
||||
{
|
||||
const ck_tile::long_index_t tid = get_block_id() * blockDim.x + get_thread_id();
|
||||
const ck_tile::long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
|
||||
// Calculate total weight elements
|
||||
ck_tile::long_index_t weight_length = G * K * C;
|
||||
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
weight_length *= wei_spatial_lengths[i];
|
||||
}
|
||||
|
||||
// Calculate strides for weight tensor (GKZYXC or GKYXC or GKXC)
|
||||
std::array<ck_tile::long_index_t, NDimSpatial + 3> wei_strides;
|
||||
ck_tile::long_index_t stride = 1;
|
||||
wei_strides[NDimSpatial + 2] = stride; // C stride
|
||||
stride *= C;
|
||||
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
|
||||
{
|
||||
wei_strides[i + 2] = stride;
|
||||
stride *= wei_spatial_lengths[i];
|
||||
}
|
||||
wei_strides[1] = stride; // K stride
|
||||
stride *= K;
|
||||
wei_strides[0] = stride; // G stride
|
||||
|
||||
// Calculate strides for input tensor (NDHWGC or NHWGC or NWGC)
|
||||
std::array<ck_tile::long_index_t, NDimSpatial + 3> in_strides;
|
||||
stride = 1;
|
||||
in_strides[NDimSpatial + 2] = stride; // C stride
|
||||
stride *= C;
|
||||
in_strides[NDimSpatial + 1] = stride; // G stride
|
||||
stride *= G;
|
||||
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
|
||||
{
|
||||
in_strides[i + 1] = stride;
|
||||
stride *= in_spatial_lengths[i];
|
||||
}
|
||||
in_strides[0] = stride; // N stride
|
||||
|
||||
// Calculate strides for output tensor (NDHWGK or NHWGK or NWGK)
|
||||
std::array<ck_tile::long_index_t, NDimSpatial + 3> out_strides;
|
||||
stride = 1;
|
||||
out_strides[NDimSpatial + 2] = stride; // K stride
|
||||
stride *= K;
|
||||
out_strides[NDimSpatial + 1] = stride; // G stride
|
||||
stride *= G;
|
||||
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
|
||||
{
|
||||
out_strides[i + 1] = stride;
|
||||
stride *= out_spatial_lengths[i];
|
||||
}
|
||||
out_strides[0] = stride; // N stride
|
||||
|
||||
// Grid-stride loop over all weight elements
|
||||
for(ck_tile::long_index_t ii = tid; ii < weight_length; ii += num_threads)
|
||||
{
|
||||
// Decode linear index to multi-dimensional indices
|
||||
ck_tile::long_index_t tmp = ii;
|
||||
|
||||
// Extract G (group)
|
||||
ck_tile::index_t g = tmp / wei_strides[0];
|
||||
tmp -= g * wei_strides[0];
|
||||
|
||||
// Extract K (output channel)
|
||||
ck_tile::index_t k = tmp / wei_strides[1];
|
||||
tmp -= k * wei_strides[1];
|
||||
|
||||
// Extract spatial dimensions (come before C in GKZYXC layout)
|
||||
ck_tile::index_t wei_spatial_idx[6];
|
||||
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
wei_spatial_idx[i] = tmp / wei_strides[i + 2];
|
||||
tmp -= wei_spatial_idx[i] * wei_strides[i + 2];
|
||||
}
|
||||
|
||||
// Extract C (input channel) - comes last
|
||||
ck_tile::index_t c = tmp;
|
||||
|
||||
// Accumulate in float
|
||||
float v_acc = 0.0f;
|
||||
|
||||
// Loop over batch
|
||||
for(ck_tile::index_t n = 0; n < N; ++n)
|
||||
{
|
||||
// Loop over output spatial dimensions
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
for(ck_tile::index_t wo = 0; wo < out_spatial_lengths[0]; ++wo)
|
||||
{
|
||||
// Calculate input spatial coordinate
|
||||
ck_tile::long_index_t wi =
|
||||
static_cast<ck_tile::long_index_t>(wo * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(wei_spatial_idx[0] *
|
||||
conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
// Bounds check
|
||||
if(wi >= 0 && wi < in_spatial_lengths[0])
|
||||
{
|
||||
std::array<ck_tile::index_t, 1> in_spatial = {static_cast<index_t>(wi)};
|
||||
std::array<ck_tile::index_t, 1> out_spatial = {
|
||||
static_cast<index_t>(wo)};
|
||||
ck_tile::long_index_t in_idx =
|
||||
detail::calculate_input_index<1>(n, g, c, in_spatial, in_strides);
|
||||
ck_tile::long_index_t out_idx = detail::calculate_output_index<1>(
|
||||
n, g, k, out_spatial, out_strides);
|
||||
|
||||
v_acc += type_convert<float>(p_out_grad[out_idx]) *
|
||||
type_convert<float>(p_in[in_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
for(ck_tile::index_t ho = 0; ho < out_spatial_lengths[0]; ++ho)
|
||||
{
|
||||
ck_tile::long_index_t hi =
|
||||
static_cast<ck_tile::long_index_t>(ho * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(wei_spatial_idx[0] *
|
||||
conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
for(ck_tile::index_t wo = 0; wo < out_spatial_lengths[1]; ++wo)
|
||||
{
|
||||
ck_tile::long_index_t wi =
|
||||
static_cast<ck_tile::long_index_t>(wo * conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(wei_spatial_idx[1] *
|
||||
conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
|
||||
// Bounds check
|
||||
if(hi >= 0 && hi < in_spatial_lengths[0] && wi >= 0 &&
|
||||
wi < in_spatial_lengths[1])
|
||||
{
|
||||
std::array<ck_tile::index_t, 2> in_spatial = {
|
||||
static_cast<index_t>(hi), static_cast<index_t>(wi)};
|
||||
std::array<ck_tile::index_t, 2> out_spatial = {
|
||||
static_cast<index_t>(ho), static_cast<index_t>(wo)};
|
||||
ck_tile::long_index_t in_idx = detail::calculate_input_index<2>(
|
||||
n, g, c, in_spatial, in_strides);
|
||||
ck_tile::long_index_t out_idx = detail::calculate_output_index<2>(
|
||||
n, g, k, out_spatial, out_strides);
|
||||
|
||||
v_acc += type_convert<float>(p_out_grad[out_idx]) *
|
||||
type_convert<float>(p_in[in_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
for(ck_tile::index_t do_ = 0; do_ < out_spatial_lengths[0]; ++do_)
|
||||
{
|
||||
ck_tile::long_index_t di =
|
||||
static_cast<ck_tile::long_index_t>(do_ * conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(wei_spatial_idx[0] *
|
||||
conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
for(ck_tile::index_t ho = 0; ho < out_spatial_lengths[1]; ++ho)
|
||||
{
|
||||
ck_tile::long_index_t hi =
|
||||
static_cast<ck_tile::long_index_t>(ho * conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(wei_spatial_idx[1] *
|
||||
conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
|
||||
for(ck_tile::index_t wo = 0; wo < out_spatial_lengths[2]; ++wo)
|
||||
{
|
||||
ck_tile::long_index_t wi =
|
||||
static_cast<ck_tile::long_index_t>(wo * conv_strides[2]) +
|
||||
static_cast<ck_tile::long_index_t>(wei_spatial_idx[2] *
|
||||
conv_dilations[2]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
|
||||
|
||||
// Bounds check
|
||||
if(di >= 0 && di < in_spatial_lengths[0] && hi >= 0 &&
|
||||
hi < in_spatial_lengths[1] && wi >= 0 &&
|
||||
wi < in_spatial_lengths[2])
|
||||
{
|
||||
std::array<ck_tile::index_t, 3> in_spatial = {
|
||||
static_cast<index_t>(di),
|
||||
static_cast<index_t>(hi),
|
||||
static_cast<index_t>(wi)};
|
||||
std::array<ck_tile::index_t, 3> out_spatial = {
|
||||
static_cast<index_t>(do_),
|
||||
static_cast<index_t>(ho),
|
||||
static_cast<index_t>(wo)};
|
||||
ck_tile::long_index_t in_idx = detail::calculate_input_index<3>(
|
||||
n, g, c, in_spatial, in_strides);
|
||||
ck_tile::long_index_t out_idx =
|
||||
detail::calculate_output_index<3>(
|
||||
n, g, k, out_spatial, out_strides);
|
||||
|
||||
v_acc += type_convert<float>(p_out_grad[out_idx]) *
|
||||
type_convert<float>(p_in[in_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert accumulator to output type and write
|
||||
p_wei_grad[ii] = type_convert<WeiDataType>(v_acc);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Host-side launcher for naive grouped convolution backward weight
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
CK_TILE_HOST float
|
||||
naive_grouped_conv_bwd_weight(const InDataType* p_in_dev,
|
||||
WeiDataType* p_wei_grad_dev,
|
||||
const OutDataType* p_out_grad_dev,
|
||||
ck_tile::index_t G,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t C,
|
||||
std::vector<ck_tile::long_index_t> in_spatial_lengths,
|
||||
std::vector<ck_tile::long_index_t> wei_spatial_lengths,
|
||||
std::vector<ck_tile::long_index_t> out_spatial_lengths,
|
||||
std::vector<ck_tile::long_index_t> conv_strides,
|
||||
std::vector<ck_tile::long_index_t> conv_dilations,
|
||||
std::vector<ck_tile::long_index_t> in_left_pads,
|
||||
ck_tile::stream_config stream_config = {})
|
||||
{
|
||||
// Convert vectors to arrays
|
||||
auto in_spatial_arr = to_array_with_default<NDimSpatial>(in_spatial_lengths);
|
||||
auto wei_spatial_arr = to_array_with_default<NDimSpatial>(wei_spatial_lengths);
|
||||
auto out_spatial_arr = to_array_with_default<NDimSpatial>(out_spatial_lengths);
|
||||
auto conv_strides_arr = to_array_with_default<NDimSpatial>(conv_strides);
|
||||
auto conv_dilations_arr = to_array_with_default<NDimSpatial>(conv_dilations);
|
||||
auto in_left_pads_arr = to_array_with_default<NDimSpatial>(in_left_pads, 0);
|
||||
|
||||
// Calculate grid size
|
||||
ck_tile::long_index_t weight_length = G * K * C;
|
||||
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
weight_length *= wei_spatial_lengths[i];
|
||||
}
|
||||
|
||||
using KernelType =
|
||||
naive_grouped_conv_bwd_weight_kernel<NDimSpatial, InDataType, WeiDataType, OutDataType>;
|
||||
|
||||
constexpr ck_tile::index_t block_size = KernelType::kBlockSize;
|
||||
const ck_tile::index_t grid_size = (weight_length + block_size - 1) / block_size;
|
||||
|
||||
// Launch kernel
|
||||
float elapsed_ms = launch_kernel(stream_config,
|
||||
make_kernel(KernelType{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0, // dynamic shared memory size
|
||||
p_in_dev,
|
||||
p_wei_grad_dev,
|
||||
p_out_grad_dev,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
in_spatial_arr,
|
||||
wei_spatial_arr,
|
||||
out_spatial_arr,
|
||||
conv_strides_arr,
|
||||
conv_dilations_arr,
|
||||
in_left_pads_arr));
|
||||
|
||||
return elapsed_ms;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
317
include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp
Normal file
317
include/ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp
Normal file
@@ -0,0 +1,317 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ref/conv_common.hpp"
|
||||
#include <array>
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Naive GPU reference kernel struct for forward grouped convolution
|
||||
// Layout: Input=NDHWGC, Weight=GKZYXC, Output=NDHWGK (for 3D case)
|
||||
// Input=NHWGC, Weight=GKYXC, Output=NHWGK (for 2D case)
|
||||
// Input=NWGC, Weight=GKXC, Output=NWGK (for 1D case)
|
||||
//
|
||||
// One thread per output element, uses grid-stride loop pattern
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
struct naive_grouped_conv_fwd_kernel
|
||||
{
|
||||
static constexpr ck_tile::index_t kBlockSize = 256;
|
||||
|
||||
__device__ void
|
||||
operator()(const InDataType* __restrict__ p_in,
|
||||
const WeiDataType* __restrict__ p_wei,
|
||||
OutDataType* __restrict__ p_out,
|
||||
// Tensor dimensions
|
||||
ck_tile::index_t G, // number of groups
|
||||
ck_tile::index_t N, // batch size
|
||||
ck_tile::index_t K, // output channels per group
|
||||
ck_tile::index_t C, // input channels per group
|
||||
// Input spatial dimensions
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& in_spatial_lengths,
|
||||
// Weight spatial dimensions
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& wei_spatial_lengths,
|
||||
// Output spatial dimensions
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& out_spatial_lengths,
|
||||
// Convolution parameters
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_strides,
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& conv_dilations,
|
||||
const std::array<ck_tile::long_index_t, NDimSpatial>& in_left_pads) const
|
||||
{
|
||||
const ck_tile::long_index_t tid = get_block_id() * blockDim.x + get_thread_id();
|
||||
const ck_tile::long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
|
||||
// Calculate total output elements
|
||||
ck_tile::long_index_t output_length = G * N * K;
|
||||
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
output_length *= out_spatial_lengths[i];
|
||||
}
|
||||
|
||||
// Calculate strides for output tensor (NDHWGK or NHWGK or NWGK)
|
||||
std::array<ck_tile::long_index_t, NDimSpatial + 3> out_strides; // N, spatial dims, G, K
|
||||
ck_tile::long_index_t stride = 1;
|
||||
out_strides[NDimSpatial + 2] = stride; // K stride
|
||||
stride *= K;
|
||||
out_strides[NDimSpatial + 1] = stride; // G stride
|
||||
stride *= G;
|
||||
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i) // Spatial strides (reversed)
|
||||
{
|
||||
out_strides[i + 1] = stride;
|
||||
stride *= out_spatial_lengths[i];
|
||||
}
|
||||
out_strides[0] = stride; // N stride
|
||||
|
||||
// Calculate strides for input tensor (NDHWGC or NHWGC or NWGC)
|
||||
std::array<ck_tile::long_index_t, NDimSpatial + 3> in_strides;
|
||||
stride = 1;
|
||||
in_strides[NDimSpatial + 2] = stride; // C stride
|
||||
stride *= C;
|
||||
in_strides[NDimSpatial + 1] = stride; // G stride
|
||||
stride *= G;
|
||||
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
|
||||
{
|
||||
in_strides[i + 1] = stride;
|
||||
stride *= in_spatial_lengths[i];
|
||||
}
|
||||
in_strides[0] = stride; // N stride
|
||||
|
||||
// Calculate strides for weight tensor (GKZYXC or GKYXC or GKXC)
|
||||
std::array<ck_tile::long_index_t, NDimSpatial + 3> wei_strides;
|
||||
stride = 1;
|
||||
wei_strides[NDimSpatial + 2] = stride; // C stride
|
||||
stride *= C;
|
||||
for(ck_tile::index_t i = NDimSpatial - 1; i >= 0; --i)
|
||||
{
|
||||
wei_strides[i + 2] = stride;
|
||||
stride *= wei_spatial_lengths[i];
|
||||
}
|
||||
wei_strides[1] = stride; // K stride
|
||||
stride *= K;
|
||||
wei_strides[0] = stride; // G stride
|
||||
|
||||
// Grid-stride loop over all output elements
|
||||
for(ck_tile::long_index_t ii = tid; ii < output_length; ii += num_threads)
|
||||
{
|
||||
// Decode linear index to multi-dimensional indices
|
||||
ck_tile::long_index_t tmp = ii;
|
||||
|
||||
// Extract N (batch)
|
||||
ck_tile::index_t n = tmp / out_strides[0];
|
||||
tmp -= n * out_strides[0];
|
||||
|
||||
// Extract spatial dimensions (D, H, W)
|
||||
ck_tile::index_t out_spatial_idx[6]; // Max 6 spatial dimensions
|
||||
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
out_spatial_idx[i] = tmp / out_strides[i + 1];
|
||||
tmp -= out_spatial_idx[i] * out_strides[i + 1];
|
||||
}
|
||||
|
||||
// Extract G (group)
|
||||
ck_tile::index_t g = tmp / out_strides[NDimSpatial + 1];
|
||||
tmp -= g * out_strides[NDimSpatial + 1];
|
||||
|
||||
// Extract K (output channel)
|
||||
ck_tile::index_t k = tmp;
|
||||
|
||||
// Accumulate in float
|
||||
float v_acc = 0.0f;
|
||||
|
||||
// Loop over input channels
|
||||
for(ck_tile::index_t c = 0; c < C; ++c)
|
||||
{
|
||||
// Loop over filter spatial dimensions
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[0]; ++x)
|
||||
{
|
||||
// Calculate input spatial coordinate
|
||||
ck_tile::long_index_t wi =
|
||||
static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
|
||||
conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
// Bounds check
|
||||
if(wi >= 0 && wi < in_spatial_lengths[0])
|
||||
{
|
||||
std::array<ck_tile::index_t, 1> in_spatial = {static_cast<index_t>(wi)};
|
||||
std::array<ck_tile::index_t, 1> wei_spatial = {x};
|
||||
ck_tile::long_index_t in_idx =
|
||||
detail::calculate_input_index<1>(n, g, c, in_spatial, in_strides);
|
||||
ck_tile::long_index_t wei_idx = detail::calculate_weight_index<1>(
|
||||
g, k, c, wei_spatial, wei_strides);
|
||||
|
||||
v_acc += type_convert<float>(p_in[in_idx]) *
|
||||
type_convert<float>(p_wei[wei_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[0]; ++y)
|
||||
{
|
||||
ck_tile::long_index_t hi =
|
||||
static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
|
||||
conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[1]; ++x)
|
||||
{
|
||||
ck_tile::long_index_t wi =
|
||||
static_cast<ck_tile::long_index_t>(out_spatial_idx[1] *
|
||||
conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
|
||||
// Bounds check
|
||||
if(hi >= 0 && hi < in_spatial_lengths[0] && wi >= 0 &&
|
||||
wi < in_spatial_lengths[1])
|
||||
{
|
||||
std::array<ck_tile::index_t, 2> in_spatial = {
|
||||
static_cast<index_t>(hi), static_cast<index_t>(wi)};
|
||||
std::array<ck_tile::index_t, 2> wei_spatial = {y, x};
|
||||
ck_tile::long_index_t in_idx = detail::calculate_input_index<2>(
|
||||
n, g, c, in_spatial, in_strides);
|
||||
ck_tile::long_index_t wei_idx = detail::calculate_weight_index<2>(
|
||||
g, k, c, wei_spatial, wei_strides);
|
||||
|
||||
v_acc += type_convert<float>(p_in[in_idx]) *
|
||||
type_convert<float>(p_wei[wei_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
for(ck_tile::index_t z = 0; z < wei_spatial_lengths[0]; ++z)
|
||||
{
|
||||
ck_tile::long_index_t di =
|
||||
static_cast<ck_tile::long_index_t>(out_spatial_idx[0] *
|
||||
conv_strides[0]) +
|
||||
static_cast<ck_tile::long_index_t>(z * conv_dilations[0]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[0]);
|
||||
|
||||
for(ck_tile::index_t y = 0; y < wei_spatial_lengths[1]; ++y)
|
||||
{
|
||||
ck_tile::long_index_t hi =
|
||||
static_cast<ck_tile::long_index_t>(out_spatial_idx[1] *
|
||||
conv_strides[1]) +
|
||||
static_cast<ck_tile::long_index_t>(y * conv_dilations[1]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[1]);
|
||||
|
||||
for(ck_tile::index_t x = 0; x < wei_spatial_lengths[2]; ++x)
|
||||
{
|
||||
ck_tile::long_index_t wi =
|
||||
static_cast<ck_tile::long_index_t>(out_spatial_idx[2] *
|
||||
conv_strides[2]) +
|
||||
static_cast<ck_tile::long_index_t>(x * conv_dilations[2]) -
|
||||
static_cast<ck_tile::long_index_t>(in_left_pads[2]);
|
||||
|
||||
// Bounds check
|
||||
if(di >= 0 && di < in_spatial_lengths[0] && hi >= 0 &&
|
||||
hi < in_spatial_lengths[1] && wi >= 0 &&
|
||||
wi < in_spatial_lengths[2])
|
||||
{
|
||||
std::array<ck_tile::index_t, 3> in_spatial = {
|
||||
static_cast<index_t>(di),
|
||||
static_cast<index_t>(hi),
|
||||
static_cast<index_t>(wi)};
|
||||
std::array<ck_tile::index_t, 3> wei_spatial = {z, y, x};
|
||||
ck_tile::long_index_t in_idx = detail::calculate_input_index<3>(
|
||||
n, g, c, in_spatial, in_strides);
|
||||
ck_tile::long_index_t wei_idx =
|
||||
detail::calculate_weight_index<3>(
|
||||
g, k, c, wei_spatial, wei_strides);
|
||||
|
||||
v_acc += type_convert<float>(p_in[in_idx]) *
|
||||
type_convert<float>(p_wei[wei_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert accumulator to output type and write
|
||||
p_out[ii] = type_convert<OutDataType>(v_acc);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Host-side launcher for naive grouped convolution forward
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
CK_TILE_HOST float naive_grouped_conv_fwd(const InDataType* p_in_dev,
|
||||
const WeiDataType* p_wei_dev,
|
||||
OutDataType* p_out_dev,
|
||||
ck_tile::index_t G,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t C,
|
||||
std::vector<ck_tile::long_index_t> in_spatial_lengths,
|
||||
std::vector<ck_tile::long_index_t> wei_spatial_lengths,
|
||||
std::vector<ck_tile::long_index_t> out_spatial_lengths,
|
||||
std::vector<ck_tile::long_index_t> conv_strides,
|
||||
std::vector<ck_tile::long_index_t> conv_dilations,
|
||||
std::vector<ck_tile::long_index_t> in_left_pads,
|
||||
ck_tile::stream_config stream_config = {})
|
||||
{
|
||||
// Convert vectors to arrays (std::array can be passed by value to kernel)
|
||||
auto in_spatial_arr = to_array_with_default<NDimSpatial>(in_spatial_lengths);
|
||||
auto wei_spatial_arr = to_array_with_default<NDimSpatial>(wei_spatial_lengths);
|
||||
auto out_spatial_arr = to_array_with_default<NDimSpatial>(out_spatial_lengths);
|
||||
auto conv_strides_arr = to_array_with_default<NDimSpatial>(conv_strides);
|
||||
auto conv_dilations_arr = to_array_with_default<NDimSpatial>(conv_dilations);
|
||||
auto in_left_pads_arr = to_array_with_default<NDimSpatial>(in_left_pads, 0);
|
||||
|
||||
// Calculate grid size
|
||||
ck_tile::long_index_t output_length = G * N * K;
|
||||
for(ck_tile::index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
output_length *= out_spatial_lengths[i];
|
||||
}
|
||||
|
||||
using KernelType =
|
||||
naive_grouped_conv_fwd_kernel<NDimSpatial, InDataType, WeiDataType, OutDataType>;
|
||||
|
||||
constexpr ck_tile::index_t block_size = KernelType::kBlockSize;
|
||||
const ck_tile::index_t grid_size = (output_length + block_size - 1) / block_size;
|
||||
|
||||
// Launch kernel
|
||||
float elapsed_ms = launch_kernel(stream_config,
|
||||
make_kernel(KernelType{},
|
||||
dim3(grid_size),
|
||||
dim3(block_size),
|
||||
0, // dynamic shared memory size
|
||||
p_in_dev,
|
||||
p_wei_dev,
|
||||
p_out_dev,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
in_spatial_arr,
|
||||
wei_spatial_arr,
|
||||
out_spatial_arr,
|
||||
conv_strides_arr,
|
||||
conv_dilations_arr,
|
||||
in_left_pads_arr));
|
||||
|
||||
return elapsed_ms;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user