mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
refactor
This commit is contained in:
@@ -344,64 +344,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc =
|
||||
make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc =
|
||||
out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old(
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}, Number<K2>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5),
|
||||
Sequence<3>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 4, 5>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 8, 1>::type,
|
||||
arithmetic_sequence_gen<0, 8, 1>::type,
|
||||
7,
|
||||
7,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#else
|
||||
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
@@ -449,7 +391,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
b_thread_data_on_global,
|
||||
0})
|
||||
.template Run_amd_experiment<Float, 0, 2>(p_out_thread, p_out_global);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -369,31 +369,37 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
static_assert(K % (K1 * K2) == 0, "wrong!");
|
||||
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t K0 = K / K1;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register
|
||||
constexpr auto out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc =
|
||||
make_native_tensor_descriptor_packed(
|
||||
Sequence<KPerBlock / (K1 * K2), 1, K2, N1, 1, 1, 1, N2>{});
|
||||
// output memory layout descriptor in register, src of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
|
||||
|
||||
// output tensor descriptor in register, src of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc =
|
||||
reorder_tensor_descriptor_given_upper2lower(out_k0_k1_k2_n1_n0_ho_wo_n2_thread_desc,
|
||||
Sequence<4, 3, 7, 0, 1, 2, 5, 6>{});
|
||||
// output memory layout descriptor in device memory
|
||||
constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc_old =
|
||||
OutGlobalDesc::Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// output memory layout descriptor in device memory, dst of threadwise copy
|
||||
constexpr auto out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(Unmerge<Sequence<N / (N1 * N2), N1, N2>>{},
|
||||
Unmerge<Sequence<K / (K1 * K2), K1, K2>>{},
|
||||
PassThrough<Ho>{},
|
||||
PassThrough<Wo>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}, Sequence<6>{}, Sequence<7>{}));
|
||||
constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = make_native_tensor_descriptor(
|
||||
out_n0_n1_n2_k0_k1_ho_wo_global_desc_old.GetLengths(),
|
||||
out_n0_n1_n2_k0_k1_ho_wo_global_desc_old.GetStrides());
|
||||
|
||||
// output merged global tensor descriptor, dst of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_global_desc = transform_tensor_descriptor(
|
||||
out_n0_n1_n2_k0_k1_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<K0>{},
|
||||
PassThrough<K1>{},
|
||||
PassThrough<N1>{},
|
||||
Merge<Sequence<N0, Ho, Wo>>{},
|
||||
PassThrough<N2>{}),
|
||||
make_tuple(Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 5, 6>{},
|
||||
Sequence<2>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -406,41 +412,20 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
// output merged global tensor descriptor, for calculating origin of thread tensor
|
||||
// in global memory
|
||||
constexpr auto out_n0_n1_n2_k_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(Unmerge<Sequence<N / (N1 * N2), N1, N2>>{},
|
||||
PassThrough<K>{},
|
||||
PassThrough<Ho>{},
|
||||
PassThrough<Wo>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}));
|
||||
|
||||
constexpr auto out_k_n1_b_n2_global_desc = transform_tensor_descriptor(
|
||||
out_n0_n1_n2_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<K>{},
|
||||
PassThrough<N1>{},
|
||||
Merge<Sequence<N0, Ho, Wo>>{},
|
||||
PassThrough<N2>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<1>{}, Sequence<0, 4, 5>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// origin of dst in device memory
|
||||
Float* p_out_thread_on_global =
|
||||
p_out_global +
|
||||
out_k_n1_b_n2_global_desc.CalculateOffset(
|
||||
{k_thread_data_on_global, 0, b_thread_data_on_global, 0});
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_global_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_ho_wo_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 8, 1>::type,
|
||||
7,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_k0_k1_n1_b_n2_thread_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_global_desc),
|
||||
decltype(
|
||||
out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
3,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
0,
|
||||
b_thread_data_on_global,
|
||||
0})
|
||||
.template Run_amd_experiment<Float, 0, 2>(p_out_thread, p_out_global);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -737,13 +737,23 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
template <typename TData>
|
||||
__device__ void RunLoadRegisterBuffer(const TData* p_src, TData* p_buffer) const
|
||||
{
|
||||
#if 0
|
||||
mThreadwiseLoad.Run(p_src, p_buffer);
|
||||
#else
|
||||
// hardcoded: global to register
|
||||
mThreadwiseLoad.template Run_amd_experiment<TData, 2, 0>(p_src, p_buffer);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename TData>
|
||||
__device__ void RunStoreRegisterBuffer(const TData* p_buffer, TData* p_dst) const
|
||||
{
|
||||
#if 0
|
||||
mThreadwiseStore.Run(p_buffer, p_dst);
|
||||
#else
|
||||
// hardcoded: register to LDS
|
||||
mThreadwiseStore.template Run_amd_experiment<TData, 0, 1>(p_buffer, p_dst);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename TData>
|
||||
@@ -751,8 +761,8 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
{
|
||||
TData p_buffer[GetRegisterBufferSize()];
|
||||
|
||||
mThreadwiseLoad.Run(p_src, p_buffer);
|
||||
mThreadwiseStore.Run(p_buffer, p_dst);
|
||||
RunLoadRegisterBuffer(p_src, p_buffer);
|
||||
RunStoreRegisterBuffer(p_buffer, p_dst);
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
|
||||
@@ -1276,6 +1276,107 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
});
|
||||
}
|
||||
|
||||
// memory-space
|
||||
// 0: VGPR
|
||||
// 1: LDS
|
||||
// 2: global-memory
|
||||
template <class TData, index_t SrcMemorySpace, index_t DstMemorySpace>
|
||||
__device__ void Run_amd_experiment(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
|
||||
using dst_vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
|
||||
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
|
||||
|
||||
constexpr auto long_vector_size = Number<math::lcm(SrcDataPerAccess, DstDataPerAccess)>{};
|
||||
|
||||
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
|
||||
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
|
||||
|
||||
ford<decltype(long_vector_access_lengths), DimAccessOrder>{}([&](
|
||||
auto long_vector_access_id) {
|
||||
|
||||
// data id w.r.t slicing-window
|
||||
auto long_vector_data_begin_id = long_vector_access_id;
|
||||
long_vector_data_begin_id(vector_access_dim) =
|
||||
long_vector_size * long_vector_access_id[vector_access_dim];
|
||||
|
||||
// buffer to hold a long-vector
|
||||
TData p_long_vector[long_vector_size];
|
||||
|
||||
// set 0
|
||||
for(index_t i = 0; i < long_vector_size; ++i)
|
||||
{
|
||||
p_long_vector[i] = 0;
|
||||
}
|
||||
|
||||
// load data from src to the long-vector buffer
|
||||
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(vector_access_dim) = i * src_data_per_access;
|
||||
|
||||
const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id);
|
||||
|
||||
// check for padding
|
||||
// TODO: still kind of messy
|
||||
if(!src_coord.IsAnyLevelIndexInPaddingArea())
|
||||
{
|
||||
const index_t src_offset = src_coord.GetOffset();
|
||||
|
||||
const index_t buffer_offset = i * src_data_per_access;
|
||||
|
||||
static_if<SrcMemorySpace == 2>{}([&](auto) {
|
||||
#if 0 // source code
|
||||
*reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
|
||||
#elif 1 // inline asm using buffer_load
|
||||
*reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
|
||||
__buffer_load<TData, SrcDataPerAccess>(
|
||||
p_src, static_cast<uint32_t>(src_offset), static_cast<uint32_t>(0));
|
||||
#endif
|
||||
}).Else([&](auto) {
|
||||
// src can be all kinds of memory-space.
|
||||
*reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// store data from the long-vector buffer to dst
|
||||
for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(vector_access_dim) = i * dst_data_per_access;
|
||||
|
||||
const index_t buffer_offset = i * dst_data_per_access;
|
||||
|
||||
const index_t dst_offset =
|
||||
(mDstSliceOrigin + (long_vector_data_begin_id + scalar_id)).GetOffset();
|
||||
|
||||
static_if<DstMemorySpace == 2>{}([&](auto) {
|
||||
#if 0 // source code
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
|
||||
*reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]);
|
||||
#elif 1 // inline asm using buffer_store
|
||||
__buffer_store<TData, DstDataPerAccess>(
|
||||
*reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]),
|
||||
p_dst,
|
||||
dst_offset,
|
||||
0);
|
||||
#endif
|
||||
}).Else([&](auto) {
|
||||
// dst can be all kinds of memory-space
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
|
||||
*reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]);
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void MoveSrcSliceWindow(const T& step_sizes_,
|
||||
integral_constant<bool, PositiveDirection>)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "hip/hip_fp16.h"
|
||||
|
||||
#define CK_DEVICE_BACKEND_AMD 1
|
||||
#define CK_USE_UNSIGNED_INDEX_TYPE 1
|
||||
#define CK_USE_UNSIGNED_INDEX_TYPE 0
|
||||
#define CK_USE_AMD_INLINE_ASM 1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
|
||||
|
||||
@@ -51,7 +51,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded(InDesc,
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
|
||||
@@ -76,12 +76,12 @@ int main(int argc, char* argv[])
|
||||
|
||||
#if 0
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 34;
|
||||
constexpr index_t WI = 34;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
@@ -103,7 +103,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x1 filter, 8x8 image
|
||||
// cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
|
||||
constexpr index_t N = 64;
|
||||
@@ -341,7 +341,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<3, 0>;
|
||||
using RightPads = Sequence<3, 0>;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 1x7 filter, 0x3 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 128;
|
||||
@@ -434,7 +434,7 @@ int main(int argc, char* argv[])
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
|
||||
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
|
||||
Reference in New Issue
Block a user