merge develop and solve conflicts

This commit is contained in:
aska-0096
2025-08-22 03:15:51 +00:00
parent f21e916a8c
commit 7bdf6a7eef
6 changed files with 675 additions and 668 deletions

View File

@@ -1801,18 +1801,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
}
_Pragma("clang diagnostic push")
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") template <
typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
constexpr index_t bytes = sizeof(T) * N;
@@ -1835,23 +1835,23 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
if constexpr(oob_conditional_check)
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
}
_Pragma("clang diagnostic pop")
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE
void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");
@@ -2787,11 +2787,10 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
}
_Pragma("clang diagnostic push")
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
@@ -2801,8 +2800,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
(__attribute__((address_space(3))) llvm_fp16x4_t*)(in_ptr);
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
@@ -2810,8 +2809,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
(__attribute__((address_space(3))) llvm_bf16x4_t*)in_ptr;
//reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
@@ -2819,8 +2818,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
(__attribute__((address_space(3))) llvm_fp8x8_t*)in_ptr;
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else

View File

@@ -1571,18 +1571,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
}
_Pragma("clang diagnostic push")
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") template <
typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
constexpr index_t bytes = sizeof(T) * N;
@@ -1605,23 +1605,23 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
if constexpr(oob_conditional_check)
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
}
_Pragma("clang diagnostic pop")
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE
void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");
@@ -2597,20 +2597,17 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
static_assert(bytes_per_thread == dword_bytes);
#endif
// LDS pointer must be attributed with the LDS address space.
as3_uint32_ptr lds_ptr =
(as3_uint32_ptr)(lds_base_ptr + lds_offset);
as3_uint32_ptr lds_ptr = (as3_uint32_ptr)(lds_base_ptr + lds_offset);
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
#endif
}
_Pragma("clang diagnostic push")
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
@@ -2620,8 +2617,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
(__attribute__((address_space(3))) llvm_fp16x4_t*)(in_ptr);
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
@@ -2629,8 +2626,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
(__attribute__((address_space(3))) llvm_bf16x4_t*)in_ptr;
//reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
@@ -2638,8 +2635,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
(__attribute__((address_space(3))) llvm_fp8x8_t*)in_ptr;
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else

View File

@@ -854,12 +854,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
__builtin_amdgcn_sched_barrier(0);
auto mainloop = [&] (index_t cur_loop,
auto mainloop = [&](index_t cur_loop,
KDataType* __restrict__ k_lds_write_ptr,
KDataType* __restrict__ k_lds_read_ptr,
KDataType* __restrict__ v_lds_write_ptr,
KDataType* __restrict__ v_lds_read_ptr) {
// move V tile windows
block_sync_lds<k_lds_insts>();
move_tile_window(v_dram_window, {kN0, 0});
@@ -1108,7 +1107,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
do
{
bool is_even_loop = i_total_loops % 2 == 0;
bool is_even_loop = i_total_loops % 2 == 0;
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
: static_cast<KDataType* __restrict__>(smem_ptrk1);
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
@@ -1117,7 +1116,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
: static_cast<VDataType* __restrict__>(smem_ptrv0);
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
: static_cast<VDataType* __restrict__>(smem_ptrv1);
mainloop(i_total_loops, k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
mainloop(
i_total_loops, k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
i_total_loops++;
} while(i_total_loops < num_total_loop);