mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 05:28:37 +00:00
merge develop and solve conflicts
This commit is contained in:
@@ -40,8 +40,8 @@ template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
int divisor = ck_tile::is_wave32() ? (FlatmmConfig::N_Warp_Tile == 32 ? 1 : 2)
|
||||
: (FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4);
|
||||
|
||||
@@ -11,17 +11,17 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_bwd_data(ck_tile::GroupedConvBwdDataHostArgs& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_bwd_data<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = args.GetFlops();
|
||||
@@ -124,11 +124,11 @@ int run_grouped_conv_bwd_data_example_with_layouts(
|
||||
output_dev_buf.ToDevice(output.data());
|
||||
|
||||
ck_tile::GroupedConvBwdDataHostArgs args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
|
||||
std::cout << "Run Grouped Conv Bwd Data kernel" << std::endl;
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
@@ -136,13 +136,13 @@ int run_grouped_conv_bwd_data_example_with_layouts(
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_bwd_data<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
|
||||
input_dev_buf.FromDevice(input.data());
|
||||
bool pass = true;
|
||||
@@ -152,17 +152,15 @@ int run_grouped_conv_bwd_data_example_with_layouts(
|
||||
ck_tile::HostTensor<InDataType> input_host_ref(in_g_n_c_wis_desc);
|
||||
input_host_ref.SetZero();
|
||||
|
||||
ck_tile::
|
||||
reference_grouped_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
|
||||
input_host_ref,
|
||||
weight,
|
||||
output,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_);
|
||||
const ck_tile::index_t GemmK =
|
||||
weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
ck_tile::reference_grouped_conv_bwd_data<NDimSpatial, InDataType, WeiDataType, OutDataType>(
|
||||
input_host_ref,
|
||||
weight,
|
||||
output,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_);
|
||||
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(input_host_ref.mData.begin(), input_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
|
||||
@@ -1801,18 +1801,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(CK_TILE_LDS_ADDR T* smem,
|
||||
}
|
||||
|
||||
_Pragma("clang diagnostic push")
|
||||
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") template <
|
||||
typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
|
||||
@@ -1835,23 +1835,23 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
if constexpr(oob_conditional_check)
|
||||
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
(as3_uint32_ptr)(smem),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
(as3_uint32_ptr)(smem),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
_Pragma("clang diagnostic pop")
|
||||
|
||||
template <index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
template <index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
CK_TILE_DEVICE
|
||||
void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
|
||||
"wrong! not implemented");
|
||||
@@ -2787,11 +2787,10 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
|
||||
#endif
|
||||
}
|
||||
|
||||
_Pragma("clang diagnostic push")
|
||||
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
|
||||
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
|
||||
#if defined(__gfx950__)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
{
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
@@ -2801,8 +2800,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
|
||||
(__attribute__((address_space(3))) llvm_fp16x4_t*)(in_ptr);
|
||||
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
@@ -2810,8 +2809,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
(__attribute__((address_space(3))) llvm_bf16x4_t*)in_ptr;
|
||||
//reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
// reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
|
||||
@@ -2819,8 +2818,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
|
||||
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
|
||||
(__attribute__((address_space(3))) llvm_fp8x8_t*)in_ptr;
|
||||
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
|
||||
@@ -1571,18 +1571,18 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
|
||||
}
|
||||
|
||||
_Pragma("clang diagnostic push")
|
||||
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
|
||||
template <typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"") template <
|
||||
typename T,
|
||||
index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
int32x4_t src_wave_buffer_resource,
|
||||
index_t src_thread_addr_offset,
|
||||
index_t src_wave_addr_offset,
|
||||
index_t src_immediate_addr_offset = 0,
|
||||
index_t flag = 0,
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
constexpr index_t bytes = sizeof(T) * N;
|
||||
|
||||
@@ -1605,23 +1605,23 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
|
||||
if constexpr(oob_conditional_check)
|
||||
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_wave_buffer_resource,
|
||||
(as3_uint32_ptr)(smem),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
|
||||
(as3_uint32_ptr)(smem),
|
||||
bytes,
|
||||
v_offset,
|
||||
src_wave_addr_offset,
|
||||
/*src_immediate_addr_offset*/ 0,
|
||||
static_cast<index_t>(coherence));
|
||||
}
|
||||
_Pragma("clang diagnostic pop")
|
||||
|
||||
template <index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
CK_TILE_DEVICE void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
template <index_t N,
|
||||
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default>
|
||||
CK_TILE_DEVICE
|
||||
void amd_buffer_store_impl_with_bytes(const thread_buffer<int8_t, N> src_thread_data,
|
||||
int32x4_t dst_wave_buffer_resource,
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
|
||||
"wrong! not implemented");
|
||||
@@ -2597,20 +2597,17 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
static_assert(bytes_per_thread == dword_bytes);
|
||||
#endif
|
||||
// LDS pointer must be attributed with the LDS address space.
|
||||
as3_uint32_ptr lds_ptr =
|
||||
(as3_uint32_ptr)(lds_base_ptr + lds_offset);
|
||||
as3_uint32_ptr lds_ptr = (as3_uint32_ptr)(lds_base_ptr + lds_offset);
|
||||
|
||||
llvm_amdgcn_raw_buffer_load_lds(
|
||||
src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
_Pragma("clang diagnostic push")
|
||||
_Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
|
||||
_Pragma("clang diagnostic push") _Pragma("clang diagnostic ignored \"-Wno-old-style-cast\"")
|
||||
#if defined(__gfx950__)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
{
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
@@ -2620,8 +2617,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__fp16)))) __fp16 llvm_fp16x4_t;
|
||||
__attribute__((address_space(3))) llvm_fp16x4_t* lds_ptr =
|
||||
(__attribute__((address_space(3))) llvm_fp16x4_t*)(in_ptr);
|
||||
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp16x4_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4f16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
|
||||
@@ -2629,8 +2626,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
|
||||
__attribute__((address_space(3))) llvm_bf16x4_t* lds_ptr =
|
||||
(__attribute__((address_space(3))) llvm_bf16x4_t*)in_ptr;
|
||||
//reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
// reinterpret_cast<__attribute__((address_space(3))) llvm_bf16x4_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
|
||||
}
|
||||
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t>)
|
||||
@@ -2638,8 +2635,8 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
typedef __attribute__((__vector_size__(2 * sizeof(index_t)))) index_t llvm_fp8x8_t;
|
||||
__attribute__((address_space(3))) llvm_fp8x8_t* lds_ptr =
|
||||
(__attribute__((address_space(3))) llvm_fp8x8_t*)in_ptr;
|
||||
//reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
// reinterpret_cast<__attribute__((address_space(3))) llvm_fp8x8_t*>(
|
||||
// reinterpret_cast<uintptr_t>(in_ptr));
|
||||
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr8_b64_v2i32(lds_ptr));
|
||||
}
|
||||
else
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -854,12 +854,11 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
auto mainloop = [&] (index_t cur_loop,
|
||||
auto mainloop = [&](index_t cur_loop,
|
||||
KDataType* __restrict__ k_lds_write_ptr,
|
||||
KDataType* __restrict__ k_lds_read_ptr,
|
||||
KDataType* __restrict__ v_lds_write_ptr,
|
||||
KDataType* __restrict__ v_lds_read_ptr) {
|
||||
|
||||
// move V tile windows
|
||||
block_sync_lds<k_lds_insts>();
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
@@ -1108,7 +1107,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
do
|
||||
{
|
||||
bool is_even_loop = i_total_loops % 2 == 0;
|
||||
bool is_even_loop = i_total_loops % 2 == 0;
|
||||
auto k_lds_write_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk0)
|
||||
: static_cast<KDataType* __restrict__>(smem_ptrk1);
|
||||
auto k_lds_read_ptr = is_even_loop ? static_cast<KDataType* __restrict__>(smem_ptrk1)
|
||||
@@ -1117,7 +1116,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv0);
|
||||
auto v_lds_read_ptr = is_even_loop ? static_cast<VDataType* __restrict__>(smem_ptrv0)
|
||||
: static_cast<VDataType* __restrict__>(smem_ptrv1);
|
||||
mainloop(i_total_loops, k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
|
||||
mainloop(
|
||||
i_total_loops, k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
|
||||
i_total_loops++;
|
||||
} while(i_total_loops < num_total_loop);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user