mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
refactor
This commit is contained in:
@@ -66,14 +66,20 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<1, 4>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<8, 32>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [E, K]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [E, K]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [K, E]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4;
|
||||
@@ -114,10 +120,16 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>{};
|
||||
|
||||
|
||||
@@ -410,7 +410,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t N = 8;
|
||||
constexpr index_t C = 8;
|
||||
constexpr index_t C = 16;
|
||||
constexpr index_t HI = 3;
|
||||
constexpr index_t WI = 18;
|
||||
constexpr index_t K = 128;
|
||||
|
||||
@@ -371,7 +371,7 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
mThreadSrcOriginalMultiId[idim_original] += StepSize;
|
||||
|
||||
mThreadSrcPartialOffsets[idim] += StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
}).Else([&](auto) {
|
||||
}).Else([&](auto fwd) {
|
||||
mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
|
||||
mThreadSrcOriginalMultiId[idim_original] -= StepSize;
|
||||
|
||||
@@ -16,8 +16,8 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
{
|
||||
static constexpr index_t nDim = SrcLengths::GetSize();
|
||||
|
||||
index_t mSrcMyThreadOffset;
|
||||
index_t mDstMyThreadOffset;
|
||||
index_t mThreadSrcOffset;
|
||||
index_t mThreadDstOffset;
|
||||
|
||||
__device__
|
||||
BlockwiseTensorSliceReorderCopy_v3(Array<index_t, nDim> src_block_data_multi_id_begin,
|
||||
@@ -128,11 +128,36 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
// optimized away???
|
||||
const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);
|
||||
|
||||
mSrcMyThreadOffset =
|
||||
mThreadSrcOffset =
|
||||
src_desc.GetOffsetFromMultiIndex(src_data_multi_id + src_block_data_multi_id_begin);
|
||||
|
||||
mDstMyThreadOffset =
|
||||
mThreadDstOffset =
|
||||
dst_desc.GetOffsetFromMultiIndex(dst_data_multi_id + dst_block_data_multi_id_begin);
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(thread_cluster_desc, "thread_cluster_desc: ");
|
||||
}
|
||||
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("id %5u %5u: "
|
||||
"thread_multi_id: %u %u, "
|
||||
"src_block_data_multi_id_begin: %u %u, "
|
||||
"src_data_multi_id: %u %u, "
|
||||
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
thread_multi_id[0],
|
||||
thread_multi_id[1],
|
||||
src_block_data_multi_id_begin[0],
|
||||
src_block_data_multi_id_begin[1],
|
||||
src_data_multi_id[0],
|
||||
src_data_multi_id[1],
|
||||
mThreadSrcOffset,
|
||||
mThreadDstOffset);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterClipboardSize()
|
||||
@@ -185,7 +210,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
thread_tensor_desc.GetOffsetFromMultiIndex(clipboard_data_multi_id);
|
||||
|
||||
threadwise_tensor_slice_copy(SrcDesc{},
|
||||
p_src + src_offset + mSrcMyThreadOffset,
|
||||
p_src + src_offset + mThreadSrcOffset,
|
||||
thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
thread_sub_tensor_lengths,
|
||||
@@ -232,7 +257,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
p_clipboard + clipboard_offset,
|
||||
DstDesc{},
|
||||
p_dst + dst_offset +
|
||||
mDstMyThreadOffset,
|
||||
mThreadDstOffset,
|
||||
thread_sub_tensor_lengths,
|
||||
MapDst2Src{});
|
||||
#else
|
||||
@@ -240,7 +265,7 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
p_clipboard + clipboard_offset,
|
||||
DstDesc{},
|
||||
p_dst + dst_offset +
|
||||
mDstMyThreadOffset,
|
||||
mThreadDstOffset,
|
||||
thread_sub_tensor_lengths,
|
||||
MapDst2Src{},
|
||||
Number<DstDataPerWrite>{});
|
||||
@@ -255,4 +280,17 @@ struct BlockwiseTensorSliceReorderCopy_v3
|
||||
RunLoadRegisterClipboard(p_src, p_clipboard);
|
||||
RunStoreRegisterClipboard(p_clipboard, p_dst);
|
||||
}
|
||||
|
||||
// this function doesn't do santiy check on whether the slicing window is out of the boundary
|
||||
// of the tensor being sliced
|
||||
template <index_t IDim_, index_t StepSize, bool PositiveDirection>
|
||||
__device__ void MoveSlicingWindowOnSourceTensor(
|
||||
Number<IDim_>, Number<StepSize>, integral_constant<bool, PositiveDirection> direction)
|
||||
{
|
||||
constexpr auto IDim = Number<IDim_>{};
|
||||
|
||||
static_if<PositiveDirection>{}([&](auto fwd) {
|
||||
mThreadSrcOffset += StepSize * fwd(SrcDesc{}).GetStride(IDim);
|
||||
}).Else([&](auto fwd) { mThreadSrcOffset -= StepSize * fwd(SrcDesc{}).GetStride(IDim); });
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,41 +1,33 @@
|
||||
#pragma once
|
||||
#include "Sequence.hip.hpp"
|
||||
|
||||
template <index_t RemainDim>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class RemainLengths>
|
||||
struct static_ford_impl
|
||||
{
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
|
||||
template <class F, class CurrentMultiIndex>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
|
||||
static_assert(RemainDim > 1, "wrong!");
|
||||
static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here");
|
||||
|
||||
constexpr auto next_length = RemainLengths{}.Front();
|
||||
|
||||
static_for<0, next_length, 1>{}([=](auto I) {
|
||||
static_ford_impl<RemainDim - 1>{}(
|
||||
f, CurrentMultiIndex{}.PushBack(I), RemainLengths{}.PopFront());
|
||||
static_for<0, RemainLengths::Front(), 1>{}([=](auto I) {
|
||||
static_ford_impl<decltype(RemainLengths::PopFront())>{}(f,
|
||||
CurrentMultiIndex::PushBack(I));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct static_ford_impl<1>
|
||||
struct static_ford_impl<Sequence<>>
|
||||
{
|
||||
// F signature: F(Sequence<Is...> multi_id)
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
// CurrentMultiIndex: Sequence<...>
|
||||
// RemainLengths: Sequence<...>
|
||||
template <class F, class CurrentMultiIndex, class RemainLengths>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
|
||||
template <class F, class CurrentMultiIndex>
|
||||
__host__ __device__ void operator()(F f, CurrentMultiIndex) const
|
||||
{
|
||||
static_assert(RemainLengths::GetSize() == 1, "wrong!");
|
||||
|
||||
constexpr index_t last_length = RemainLengths{}.Front();
|
||||
|
||||
static_for<0, last_length, 1>{}([=](auto I) { f(CurrentMultiIndex{}.PushBack(I)); });
|
||||
f(CurrentMultiIndex{});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -43,16 +35,13 @@ struct static_ford_impl<1>
|
||||
template <class Lengths>
|
||||
struct static_ford
|
||||
{
|
||||
// F signature: F(Sequence<Is...> multi_id)
|
||||
// F signature: F(Sequence<...> multi_id)
|
||||
template <class F>
|
||||
__host__ __device__ void operator()(F f) const
|
||||
{
|
||||
constexpr index_t first_length = Lengths{}.Front();
|
||||
static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
|
||||
|
||||
static_for<0, first_length, 1>{}([=](auto I) {
|
||||
static_ford_impl<Lengths::GetSize() - 1>{}(
|
||||
f, Sequence<I.Get()>{}, Lengths{}.PopFront());
|
||||
});
|
||||
static_ford_impl<Lengths>{}(f, Sequence<>{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -30,10 +30,16 @@ template <index_t GridSize,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_E_N1_B_N2,
|
||||
class InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
class InBlockCopyThreadClusterArrangeOrder,
|
||||
class InBlockCopySrcAccessOrder,
|
||||
class InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
class WeiBlockCopySubLengths_E_K,
|
||||
class WeiBlockCopyClusterLengths_E_K,
|
||||
class WeiBlockCopyThreadClusterArrangeOrder,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
@@ -146,19 +152,20 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_e_n1_b_n2_global_merged_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
Sequence<0, 1, 3, 2>, // thread_arrange_order [E, N1, N2, B]
|
||||
Sequence<0, 1, 3, 2>, // src_access_order [E, N1, N2, B]
|
||||
Sequence<0, 1, 2, 3>, // dst_access_order [E, N1, B, N2]
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
decltype(in_e_n1_b_n2_global_merged_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>(
|
||||
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
@@ -171,9 +178,10 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<mod_conv::max(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already have blockwise offset built-in
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already have blockwise offset built-in
|
||||
#if 0
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
@@ -182,12 +190,28 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
Sequence<1, 0>, // thread_arrange_order [K, E]
|
||||
Sequence<1, 0>, // src_access_order [K, E]
|
||||
Sequence<0, 1>, // dst_access_order [E, K]
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
#else
|
||||
constexpr auto map_k_e_2_e_k = Sequence<1, 0>{};
|
||||
|
||||
auto blockwise_wei_copy = BlockwiseTensorSliceReorderCopy_v3<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(wei_e_k_global_desc.ReorderGivenNew2Old(map_k_e_2_e_k)),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths().ReorderGivenNew2Old(map_k_e_2_e_k)),
|
||||
decltype(WeiBlockCopySubLengths_E_K::ReorderGivenNew2Old(map_k_e_2_e_k)),
|
||||
decltype(WeiBlockCopyClusterLengths_E_K::ReorderGivenNew2Old(map_k_e_2_e_k)),
|
||||
Sequence<1, 0>, // MapDst2Src
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>({k_block_data_on_global, 0}, {0, 0});
|
||||
#endif
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
@@ -261,7 +285,7 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
#if 0
|
||||
if(e == 0 * EPerBlock && get_block_1d_id() == 0)
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("id %5u %5u: "
|
||||
"mThreadSrcOffset %u, mThreadDstOffset %u \n",
|
||||
@@ -272,6 +296,10 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0 // debug
|
||||
return;
|
||||
#endif
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
|
||||
@@ -308,7 +336,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
#if 0
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
#else
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I1, Number<EPerBlock>{}, True);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -334,7 +366,11 @@ struct GridwiseConvolutionImplicitGemm_v4_lds_double_buffer_nchw_kcyx_nkhw
|
||||
|
||||
// even iteration
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
#if 0
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
#else
|
||||
blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I1, Number<EPerBlock>{}, True);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user