Merge commit 'de61e554938265a5d17a1bba8c148457125e80cd' into develop

This commit is contained in:
assistant-librarian[bot]
2025-08-25 13:20:15 +00:00
parent 2b1a426db5
commit bb07450f2c
10 changed files with 217 additions and 151 deletions

View File

@@ -1833,14 +1833,17 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
if constexpr(oob_conditional_check)
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// Use C-style cast to change address space without dropping llvm noalias attribute
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
#pragma clang diagnostic pop
}
template <index_t N,
@@ -2788,23 +2791,26 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
#define __LDS_ADDR __attribute__((address_space(3)))
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
"We need to have the compatible compiler version to build this instruction");
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// Use C-style cast to change address space without dropping llvm noalias attribute
const auto in_ptr_ = (__LDS_ADDR T*)(const_cast<T*>(in_ptr));
#pragma clang diagnostic pop
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_fp16x4_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
@@ -2812,15 +2818,14 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else
{
static_assert(false, "not implemented");
}
#undef __LDS_ADDR
}
#endif

View File

@@ -1603,14 +1603,17 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
if constexpr(oob_conditional_check)
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
reinterpret_cast<as3_uint32_ptr>(reinterpret_cast<uintptr_t>(smem)),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// Use C-style cast to change address space without dropping llvm noalias attribute
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
#pragma clang diagnostic pop
}
template <index_t N,
@@ -2606,23 +2609,26 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
#define __LDS_ADDR __attribute__((address_space(3)))
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
"We need to have the compatible compiler version to build this instruction");
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// Use C-style cast to change address space without dropping llvm noalias attribute
const auto in_ptr_ = (__LDS_ADDR T*)(const_cast<T*>(in_ptr));
#pragma clang diagnostic pop
if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::half_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_fp16x4_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||
@@ -2630,15 +2636,14 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
std::is_same_v<remove_cvref_t<T>, ck_tile::int8_t>)
{
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_i32x2_t;
__attribute__((address_space(3))) llvm_i32x2_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) llvm_i32x2_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_i32x2_t*>(in_ptr_);
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else
{
static_assert(false, "not implemented");
}
#undef __LDS_ADDR
}
#endif

View File

@@ -62,12 +62,12 @@ struct buffer_view<address_space_enum::generic,
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
@@ -243,7 +243,7 @@ struct buffer_view<address_space_enum::global,
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
: p_data_{p_data},
buffer_size_{buffer_size / PackedSize},
cached_buf_res_{0},
@@ -251,7 +251,7 @@ struct buffer_view<address_space_enum::global,
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data},
@@ -762,12 +762,12 @@ struct buffer_view<address_space_enum::lds,
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
@@ -1121,12 +1121,12 @@ struct buffer_view<address_space_enum::vgpr,
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data, BufferSizeType buffer_size)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
{
}
CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
CK_TILE_HOST_DEVICE constexpr buffer_view(T* __restrict__ p_data,
BufferSizeType buffer_size,
T invalid_element_value)
: p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
@@ -1253,7 +1253,7 @@ template <address_space_enum BufferAddressSpace,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
typename T,
typename BufferSizeType>
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size)
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size)
{
return buffer_view<BufferAddressSpace, T, BufferSizeType, true, Coherence>{p, buffer_size};
}
@@ -1266,7 +1266,7 @@ template <address_space_enum BufferAddressSpace,
typename std::enable_if<std::is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size, X invalid_element_value)
{
return buffer_view<BufferAddressSpace, T, BufferSizeType, false, Coherence>{
p, buffer_size, invalid_element_value};

View File

@@ -449,7 +449,7 @@ template <address_space_enum BufferAddressSpace = address_space_enum::generic,
amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
typename DataType,
typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* __restrict__ p,
const tensor_descriptor<Ts...>& desc)
{
auto buffer_view =
@@ -468,7 +468,7 @@ template <address_space_enum BufferAddressSpace = address_space_enum::generic,
index_t GuaranteedLastDimensionVectorStride = -1,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view(DataType* p,
make_naive_tensor_view(DataType* __restrict__ p,
const tuple<Lengths...>& lengths,
const tuple<Strides...>& strides,
number<GuaranteedLastDimensionVectorLength> = number<-1>{},
@@ -491,7 +491,7 @@ template <address_space_enum BufferAddressSpace = address_space_enum::generic,
typename... Lengths,
index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view_packed(DataType* p,
make_naive_tensor_view_packed(DataType* __restrict__ p,
const tuple<Lengths...>& lengths,
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{

View File

@@ -1115,7 +1115,8 @@ struct FmhaBwdDQDKDVKernel
{i_n0, 0});
if constexpr(!kUseQrQtrDorPipeline)
{
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(smem_ptr,
q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
@@ -1131,7 +1132,6 @@ struct FmhaBwdDQDKDVKernel
kargs.scale,
rp_undrop,
scale_rp_undrop,
smem_ptr,
dropout);
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
@@ -1139,7 +1139,8 @@ struct FmhaBwdDQDKDVKernel
}
else
{
FmhaPipeline{}(q_dram_window,
FmhaPipeline{}(smem_ptr,
q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
@@ -1160,7 +1161,6 @@ struct FmhaBwdDQDKDVKernel
kargs.scale,
rp_undrop,
scale_rp_undrop,
smem_ptr,
dropout);
}
}

View File

@@ -93,7 +93,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
operator()(void* smem_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
@@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
float scale,
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
FmhaDropout& dropout) const
{
static_assert(

View File

@@ -93,7 +93,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
operator()(void* smem_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
@@ -109,7 +110,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
float scale,
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
FmhaDropout& dropout) const
{
static_assert(

View File

@@ -90,6 +90,53 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
else
return raw_lse;
};
// Entry point of the pipeline: carves the caller-supplied scratch buffer into
// one region per tensor and forwards those typed pointers, followed by every
// remaining argument (perfectly forwarded, untouched), to run().
//
// Byte offsets relative to smem_ptr, exactly as computed below:
//   K    : 0
//   V    : GetSmemSizeK
//   dO #0: 0                        (same address as K)
//   dO #1: dO #0 + GetSmemSizeOGrad
//   Q  #0: dO #1 + GetSmemSizeOGrad
//   Q  #1: Q  #0 + GetSmemSizeQ
//   LSE  : Q  #1 + GetSmemSizeQ
//   D    : LSE   + GetSmemSizeLSE
//   dS   : D     + GetSmemSizeD
//   bias : same address as dS
//
// NOTE(review): K/dO #0 and dS/bias are the same address, yet run() receives
// all of them as __restrict__ parameters -- confirm the aliased pairs are never
// live at the same time inside run().
template <typename... Ts>
CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const
{
    // Byte-granular view of the buffer so offsets can be added directly.
    char* const lds_base = reinterpret_cast<char*>(smem_ptr);

    // Region sizes, each queried once.
    // NOTE(review): presumably these Policy getters are pure compile-time
    // constants -- confirm against the Policy definition.
    const auto k_bytes     = Policy::template GetSmemSizeK<Problem>();
    const auto ograd_bytes = Policy::template GetSmemSizeOGrad<Problem>();
    const auto q_bytes     = Policy::template GetSmemSizeQ<Problem>();
    const auto lse_bytes   = Policy::template GetSmemSizeLSE<Problem>();
    const auto d_bytes     = Policy::template GetSmemSizeD<Problem>();

    // Running start addresses of the back-to-back regions that follow dO #0.
    char* const do1_base = lds_base + ograd_bytes;
    char* const q0_base  = do1_base + ograd_bytes;
    char* const q1_base  = q0_base + q_bytes;
    char* const lse_base = q1_base + q_bytes;
    char* const d_base   = lse_base + lse_bytes;
    char* const ds_base  = d_base + d_bytes;

    // bias is derived from the typed dS pointer, so keep that one named.
    GemmDataType* const ds_region = reinterpret_cast<GemmDataType*>(ds_base);

    return run(reinterpret_cast<KDataType*>(lds_base),
               reinterpret_cast<VDataType*>(lds_base + k_bytes),
               reinterpret_cast<OGradDataType*>(lds_base),
               reinterpret_cast<OGradDataType*>(do1_base),
               reinterpret_cast<QDataType*>(q0_base),
               reinterpret_cast<QDataType*>(q1_base),
               reinterpret_cast<LSEDataType*>(lse_base),
               reinterpret_cast<DDataType*>(d_base),
               ds_region,
               reinterpret_cast<BiasDataType*>(ds_region),
               std::forward<Ts>(args)...);
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
@@ -102,7 +149,17 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_DEVICE auto operator()( //
CK_TILE_DEVICE auto run( //
KDataType* __restrict__ k_lds_ptr,
VDataType* __restrict__ v_lds_ptr,
OGradDataType* __restrict__ do_lds_ptr0,
OGradDataType* __restrict__ do_lds_ptr1,
QDataType* __restrict__ q_lds_ptr0,
QDataType* __restrict__ q_lds_ptr1,
LSEDataType* __restrict__ lse_lds_ptr,
DDataType* __restrict__ d_lds_ptr,
GemmDataType* __restrict__ ds_lds_ptr,
BiasDataType* __restrict__ bias_lds_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
@@ -119,7 +176,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
float scale,
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
FmhaDropout& dropout) const
{
static_assert(
@@ -184,40 +240,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
}
}
// LDS allocation
const auto smem_ptr_ =
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
const auto do_lds_ptr0 = reinterpret_cast<OGradDataType* __restrict__>(smem_ptr_);
const auto do_lds_ptr1 = reinterpret_cast<OGradDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr0 = reinterpret_cast<QDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>());
const auto q_lds_ptr1 = reinterpret_cast<QDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>());
const auto lse_lds_ptr = reinterpret_cast<LSEDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>());
const auto d_lds_ptr = reinterpret_cast<DDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>());
const auto ds_lds_ptr = reinterpret_cast<GemmDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeLSE<Problem>() + Policy::template GetSmemSizeD<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType* __restrict__>(ds_lds_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
@@ -453,13 +475,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
decltype(load_tile(d_dram_window)) d_block_tile;
index_t i_total_bodys = 0;
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
const bool is_even = (i_total_bodys % 2 == 0);
QDataType* const __restrict__ q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
QDataType* const __restrict__ q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
OGradDataType* const __restrict__ do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
OGradDataType* const __restrict__ do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
auto main_body_impl = [&](auto is_prologue_,
auto is_epilogue_,
QDataType* const __restrict__ q_lds_ptr_curr,
QDataType* const __restrict__ q_lds_ptr_next,
OGradDataType* const __restrict__ do_lds_ptr_curr,
OGradDataType* const __restrict__ do_lds_ptr_next) mutable {
constexpr bool is_prologue = is_prologue_.value;
constexpr bool is_epilogue = is_epilogue_.value;
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
@@ -467,19 +488,19 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
if constexpr(is_prologue)
{
lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
q_lds_write_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_next);
async_load_tile(q_lds_write_window, q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
do_lds_write_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_next);
async_load_tile(do_lds_write_window, do_dram_window);
move_tile_window(do_dram_window, {kM0, 0});
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
}
if constexpr(is_epilogue)
{
@@ -611,8 +632,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = p[i_j_idx] >= 0;
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
? (dp_acc[i_j_idx] - d[i_idx])
: d[i_idx]);
? (dp_acc[i_j_idx] - d[i_idx])
: d[i_idx]);
});
});
@@ -725,6 +746,20 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
}
move_tile_window(dq_dram_window, {kM0, 0});
}
};
// Per-iteration driver: picks the "current" and "next" Q / dO LDS buffers from
// the parity of i_total_bodys, runs main_body_impl on them, then bumps the
// counter. is_prologue_ / is_epilogue_ are passed through unchanged.
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
// Even count -> "curr" is buffer 1 and "next" is buffer 0; odd count swaps them.
const bool is_even = (i_total_bodys % 2 == 0);
const auto q_lds_ptr_curr = is_even ? q_lds_ptr1 : q_lds_ptr0;
const auto q_lds_ptr_next = is_even ? q_lds_ptr0 : q_lds_ptr1;
const auto do_lds_ptr_curr = is_even ? do_lds_ptr1 : do_lds_ptr0;
const auto do_lds_ptr_next = is_even ? do_lds_ptr0 : do_lds_ptr1;
// The selected pointers are handed over as call arguments; main_body_impl
// declares the matching parameters as `* const __restrict__`.
// NOTE(review): presumably they are passed as parameters rather than captured
// so the restrict qualification applies to them -- confirm intent.
main_body_impl(is_prologue_,
is_epilogue_,
q_lds_ptr_curr,
q_lds_ptr_next,
do_lds_ptr_curr,
do_lds_ptr_next);
// Advance the count so the next call sees the opposite parity.
i_total_bodys += 1;
};

View File

@@ -93,6 +93,42 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
return raw_lse;
};
// Entry point of the pipeline: carves the caller-supplied scratch buffer into
// one region per tensor and forwards those typed pointers, followed by every
// remaining argument (perfectly forwarded, untouched), to run().
//
// Byte offsets relative to smem_ptr, exactly as computed below:
//   K   : 0
//   V   : GetSmemSizeK
//   dO  : 0                      (same address as K)
//   Q   : GetSmemSizeOGrad
//   LSE : Q   + GetSmemSizeQ
//   D   : LSE + GetSmemSizeLSE
//   dS  : V   + GetSmemSizeV     (i.e. GetSmemSizeK + GetSmemSizeV)
//   bias: same address as dS
//
// NOTE(review): K/dO and dS/bias are the same address, yet run() receives all
// of them as __restrict__ parameters -- confirm the aliased pairs are never
// live at the same time inside run().
template <typename... Ts>
CK_TILE_DEVICE auto operator()(void* smem_ptr, Ts&&... args) const
{
    // Byte-granular view of the buffer so offsets can be added directly.
    char* const lds_base = reinterpret_cast<char*>(smem_ptr);

    // Region sizes, each queried once.
    // NOTE(review): presumably these Policy getters are pure compile-time
    // constants -- confirm against the Policy definition.
    const auto k_bytes     = Policy::template GetSmemSizeK<Problem>();
    const auto v_bytes     = Policy::template GetSmemSizeV<Problem>();
    const auto ograd_bytes = Policy::template GetSmemSizeOGrad<Problem>();
    const auto q_bytes     = Policy::template GetSmemSizeQ<Problem>();
    const auto lse_bytes   = Policy::template GetSmemSizeLSE<Problem>();

    // First chain of regions: K followed by V, then dS after V.
    char* const v_base  = lds_base + k_bytes;
    char* const ds_base = v_base + v_bytes;

    // Second chain, also starting at the buffer base: dO, Q, LSE, D.
    char* const q_base   = lds_base + ograd_bytes;
    char* const lse_base = q_base + q_bytes;
    char* const d_base   = lse_base + lse_bytes;

    // bias is derived from the typed dS pointer, so keep that one named.
    GemmDataType* const ds_region = reinterpret_cast<GemmDataType*>(ds_base);

    return run(reinterpret_cast<KDataType* __restrict__>(lds_base),
               reinterpret_cast<VDataType* __restrict__>(v_base),
               reinterpret_cast<OGradDataType*>(lds_base),
               reinterpret_cast<QDataType*>(q_base),
               reinterpret_cast<LSEDataType*>(lse_base),
               reinterpret_cast<DDataType*>(d_base),
               ds_region,
               reinterpret_cast<BiasDataType*>(ds_region),
               std::forward<Ts>(args)...);
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
@@ -109,7 +145,15 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
typename KGradEpilogue,
typename VGradEpilogue,
typename PositionEncoding>
CK_TILE_DEVICE auto operator()( //
CK_TILE_DEVICE auto run( //
KDataType* __restrict__ k_lds_ptr,
VDataType* __restrict__ v_lds_ptr,
OGradDataType* __restrict__ do_lds_ptr,
QDataType* __restrict__ q_lds_ptr,
LSEDataType* __restrict__ lse_lds_ptr,
DDataType* __restrict__ d_lds_ptr,
GemmDataType* __restrict__ ds_lds_ptr,
BiasDataType* __restrict__ bias_lds_ptr,
const QDramBlockWindowTmp& q_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
@@ -131,7 +175,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
float scale,
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
FmhaDropout& dropout) const
{
static_assert(
@@ -181,29 +224,6 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
{seqlen_kv_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
// LDS allocation
const auto smem_ptr_ =
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
const auto do_lds_ptr = reinterpret_cast<OGradDataType*>(smem_ptr_);
const auto q_lds_ptr = reinterpret_cast<QDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>( //
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>());
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>());
const auto ds_lds_ptr =
reinterpret_cast<GemmDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeV<Problem>());
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -638,11 +638,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
operator()(const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& __restrict__ bias_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& __restrict__ lse_acc_dram_window_tmp, // M0*1 tile
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
@@ -854,18 +854,10 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
__builtin_amdgcn_sched_barrier(0);
auto mainloop = [&](index_t cur_loop) {
const bool is_even_loop = (cur_loop % 2 == 0);
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
: static_cast<KDataType* __restrict__>(smem_ptrk1);
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
: static_cast<KDataType* __restrict__>(smem_ptrk0);
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
: static_cast<VDataType* __restrict__>(smem_ptrv0);
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
: static_cast<VDataType* __restrict__>(smem_ptrv1);
auto mainloop = [&](KDataType* __restrict__ k_lds_write_ptr,
KDataType* __restrict__ k_lds_read_ptr,
KDataType* __restrict__ v_lds_write_ptr,
KDataType* __restrict__ v_lds_read_ptr) {
// move V tile windows
block_sync_lds<k_lds_insts>();
move_tile_window(v_dram_window, {kN0, 0});
@@ -1110,11 +1102,20 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS_READ
});
};
}; // mainloop
do
{
mainloop(i_total_loops);
bool is_even_loop = i_total_loops % 2 == 0;
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
: static_cast<KDataType* __restrict__>(smem_ptrk1);
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
: static_cast<KDataType* __restrict__>(smem_ptrk0);
auto v_lds_write_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv1)
: static_cast<VDataType* __restrict__>(smem_ptrv0);
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
: static_cast<VDataType* __restrict__>(smem_ptrv1);
mainloop(k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
i_total_loops++;
} while(i_total_loops < num_total_loop);