merge develop and solve conflicts

This commit is contained in:
aska-0096
2025-08-22 03:15:51 +00:00
parent f21e916a8c
commit 7bdf6a7eef
6 changed files with 675 additions and 668 deletions

View File

@@ -40,8 +40,8 @@ template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2)
: (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4);

View File

@@ -11,17 +11,17 @@ template <ck_tile::index_t NDimSpatial,
typename WeiLayout,
typename OutLayout>
float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
int n_warmup,
int n_repeat)
int n_warmup,
int n_repeat)
{
float ave_time = grouped_conv_bwd_data<NDimSpatial,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = args.GetFlops();
@@ -124,11 +124,11 @@ int run_grouped_conv_bwd_data_example_with_layouts(
output_dev_buf.ToDevice(output.data());
ck_tile::GroupedConvBwdDataHostArgs args(conv_param,
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{},
output_dev_buf.GetDeviceBuffer(),
kbatch);
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{},
output_dev_buf.GetDeviceBuffer(),
kbatch);
std::cout << "Run Grouped Conv Bwd Data kernel" << std::endl;
std::cout << "input: " << input.mDesc << std::endl;
@@ -136,13 +136,13 @@ int run_grouped_conv_bwd_data_example_with_layouts(
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_bwd_data<NDimSpatial,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
input_dev_buf.FromDevice(input.data());
bool pass = true;
@@ -152,17 +152,15 @@ int run_grouped_conv_bwd_data_example_with_layouts(
ck_tile::HostTensor<InDataType> input_host_ref(in_g_n_c_wis_desc);
input_host_ref.SetZero();
ck_tile::
reference_grouped_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
input_host_ref,
weight,
output,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_);
const ck_tile::index_t GemmK =
weight.get_element_size() / (conv_param.G_ * conv_param.K_);
ck_tile::reference_grouped_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
input_host_ref,
weight,
output,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_);
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
const float max_accumulated_value =
*std::max_element(input_host_ref.mData.begin(), input_host_ref.mData.end());
const auto rtol_atol =

View File

@@ -1801,18 +1801,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
}
_Pragma("clang diagnostic push")
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") template <
typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
constexpr index_t bytes = sizeof(T) * N;
@@ -1835,23 +1835,23 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
if constexpr(oob_conditional_check)
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
}
_Pragma("clang diagnostic pop")
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE
void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");
@@ -2787,11 +2787,10 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
}
_Pragma("clang diagnostic push")
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
@@ -2801,8 +2800,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
(__attribute__((address_space(3))) llvm_fp16x4_t*)(in_ptr);
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
@@ -2810,8 +2809,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
(__attribute__((address_space(3))) llvm_bf16x4_t*)in_ptr;
//reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
@@ -2819,8 +2818,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
(__attribute__((address_space(3))) llvm_fp8x8_t*)in_ptr;
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else

View File

@@ -1571,18 +1571,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
}
_Pragma("clang diagnostic push")
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") template <
typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
constexpr index_t bytes = sizeof(T) * N;
@@ -1605,23 +1605,23 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
if constexpr(oob_conditional_check)
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
llvm_amdgcn_raw_buffer_load_lds(
src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
}
_Pragma("clang diagnostic pop")
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
template <index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
CK_TILE_DEVICE
void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
int32x4_t dst_wave_buffer_resource,
index_t dst_thread_addr_offset,
index_t dst_wave_addr_offset)
{
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
"wrong! not implemented");
@@ -2597,20 +2597,17 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
static_assert(bytes_per_thread == dword_bytes);
#endif
// LDS pointer must be attributed with the LDS address space.
as3_uint32_ptr lds_ptr =
(as3_uint32_ptr)(lds_base_ptr + lds_offset);
as3_uint32_ptr lds_ptr = (as3_uint32_ptr)(lds_base_ptr + lds_offset);
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
#endif
}
_Pragma("clang diagnostic push")
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
#if defined(__gfx950__)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
template <typename T, index_t N, address_space_enum BufferAddressSpace>
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
{
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
@@ -2620,8 +2617,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
(__attribute__((address_space(3))) llvm_fp16x4_t*)(in_ptr);
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
@@ -2629,8 +2626,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
(__attribute__((address_space(3))) llvm_bf16x4_t*)in_ptr;
//reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
@@ -2638,8 +2635,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
(__attribute__((address_space(3))) llvm_fp8x8_t*)in_ptr;
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
// reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
}
else

View File

@@ -854,12 +854,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
__builtin_amdgcn_sched_barrier(0);
auto mainloop = [&] (index_t cur_loop,
auto mainloop = [&](index_t cur_loop,
KDataType* __restrict__ k_lds_write_ptr,
KDataType* __restrict__ k_lds_read_ptr,
KDataType* __restrict__ v_lds_write_ptr,
KDataType* __restrict__ v_lds_read_ptr) {
// move V tile windows
block_sync_lds<k_lds_insts>();
move_tile_window(v_dram_window, {kN0, 0});
@@ -1108,7 +1107,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
do
{
bool is_even_loop = i_total_loops % 2 == 0;
bool is_even_loop = i_total_loops % 2 == 0;
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
: static_cast<KDataType* __restrict__>(smem_ptrk1);
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
@@ -1117,7 +1116,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
: static_cast<VDataType* __restrict__>(smem_ptrv0);
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
: static_cast<VDataType* __restrict__>(smem_ptrv1);
mainloop(i_total_loops, k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
mainloop(
i_total_loops, k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
i_total_loops++;
} while(i_total_loops < num_total_loop);