mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
experimenting new merged tensor copy
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_2d_tensor_op.hpp"
|
||||
#include "blockwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
#include "threadwise_4d_tensor_op.hpp"
|
||||
#include "blockwise_batched_gemm.hpp"
|
||||
@@ -122,8 +123,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
@@ -132,6 +134,39 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
InBlockCopyDataPerRead_N>{};
|
||||
#elif 0
|
||||
using InBlockCopySubLengths_CHWN =
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths() / InBlockCopyClusterLengths_CHWN{});
|
||||
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_CHWN,
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
1,
|
||||
1>({0, 0, 0, 0}, {0, 0, 0, 0});
|
||||
#elif 1
|
||||
using InBlockCopySubLengths_CHWN =
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths() / InBlockCopyClusterLengths_CHWN{});
|
||||
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
NormalTensorCoordinate<decltype(in_c_h_w_n_global_desc)>,
|
||||
NormalTensorCoordinate<decltype(in_c_h_w_n_block_desc)>,
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_CHWN,
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
Sequence<0, 1, 2, 3>>({0, 0, 0, 0}, {0, 0, 0, 0});
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, X * KPerBlock]
|
||||
|
||||
@@ -224,7 +224,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
|
||||
// c_thread_mtx definition: this is a mess
|
||||
// TODO: find a more elegant way of defining c_thread_mtx
|
||||
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
|
||||
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
|
||||
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
|
||||
@@ -258,9 +258,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr index_t in_block_space =
|
||||
in_e_n1_b_n2_block_desc.GetElementSpace(Number<max_align>{});
|
||||
math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
constexpr index_t wei_block_space = wei_e_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
constexpr index_t wei_block_space =
|
||||
math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
@@ -161,7 +161,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
Float,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
MergedTensorCoordinate<decltype(wei_e_k_global_desc)>,
|
||||
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
|
||||
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
@@ -301,7 +301,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
b_thread_data_on_global % B1});
|
||||
|
||||
threadwise_out_copy.Run(p_out_thread, p_out_thread_on_global);
|
||||
#else
|
||||
#elif 0
|
||||
// This is a hack, because slicing a merged dimension is not supported yet.
|
||||
// This should be replaced with the logic above, once support for slicing a merged dimension
|
||||
// becomes available
|
||||
@@ -328,6 +328,49 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
k_thread_data_on_global % K1,
|
||||
b_thread_data_on_global});
|
||||
|
||||
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
|
||||
{
|
||||
threadwise_out_copy.Run(p_out_thread, p_out_global);
|
||||
|
||||
threadwise_out_copy.MoveSrcSlicingWindow({0, 0, GemmNPerThreadSubC}, true);
|
||||
threadwise_out_copy.MoveDstSlicingWindow({0, 0, B1}, true);
|
||||
}
|
||||
#elif 1
|
||||
// This is a hack, because slicing a merged dimension is not supported yet.
|
||||
// This should be replaced with the logic above, once support for slicing a merged dimension
|
||||
// becomes available
|
||||
// dst descriptor
|
||||
constexpr auto out_k0_k1_b_global_desc =
|
||||
make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}),
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<0, 3, 4>{});
|
||||
|
||||
// src descriptor
|
||||
constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});
|
||||
|
||||
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2<
|
||||
Float,
|
||||
#if 1 // debug
|
||||
decltype(out_k0_k1_b_thread_desc),
|
||||
decltype(out_k0_k1_b_global_desc),
|
||||
NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>,
|
||||
MergedTensorCoordinate<decltype(out_k0_k1_b_global_desc)>,
|
||||
#else
|
||||
decltype(out_k0_k1_b_thread_desc),
|
||||
decltype(
|
||||
make_ConstantTensorDescriptor_packed(out_k0_k1_b_global_desc.GetLengths())),
|
||||
NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>,
|
||||
NormalTensorCoordinate<decltype(
|
||||
make_ConstantTensorDescriptor_packed(out_k0_k1_b_global_desc.GetLengths()))>,
|
||||
#endif
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>>(
|
||||
{0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
b_thread_data_on_global});
|
||||
|
||||
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
|
||||
{
|
||||
threadwise_out_copy.Run(p_out_thread, p_out_global);
|
||||
|
||||
@@ -93,6 +93,17 @@ struct NormalTensorCoordinate
|
||||
return coord;
|
||||
}
|
||||
|
||||
// Reposition the point of origin, and return the compensated offset.
// Moves this coordinate back to the origin (all-zero index, zero offset) and
// returns the linear offset it held beforehand, so that the caller can
// compensate for the reset (for example by adding the returned value to a
// base pointer).
|
||||
__host__ __device__ constexpr index_t RepositionOrigin()
|
||||
{
|
||||
// remember the offset that is about to be discarded
index_t offset_diff = mOffset;
|
||||
|
||||
// zero both the multi-index and the linear offset
mIndex = make_zero_array<index_t, nDim>();
|
||||
mOffset = 0;
|
||||
|
||||
return offset_diff;
|
||||
}
|
||||
|
||||
// private:
|
||||
Array<index_t, nDim> mIndex;
|
||||
index_t mOffset;
|
||||
@@ -305,6 +316,29 @@ struct MergedTensorCoordinate
|
||||
return coord;
|
||||
}
|
||||
|
||||
// Reposition the point of origin, and return the compensated offset.
// Only dimensions that do NOT contain multiple original dimensions are reset:
// their index, their single original index and their partial offset are
// zeroed, and the removed partial offsets are summed into the return value.
// Dimensions that contain multiple original dimensions are left untouched, so
// mOffset keeps their contribution.
|
||||
__host__ __device__ constexpr index_t RepositionOrigin()
|
||||
{
|
||||
// accumulates the partial offsets removed from mOffset
index_t offset_diff = 0;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim_) {
|
||||
constexpr auto idim = decltype(idim_){};
|
||||
|
||||
// skip dimensions that contain more than one original dimension
static_if<!tensor_desc_type::ContainMultipleOriginalDimensions(idim)>{}([&](auto) {
|
||||
// this dimension maps to exactly one original dimension; take it
constexpr auto idim_original =
|
||||
tensor_desc_type::GetContainedOriginalDimensions(idim).Front();
|
||||
|
||||
mIndex(idim) = 0;
|
||||
mOriginalIndex(idim_original) = 0;
|
||||
// order matters: the partial offset is read twice below before it is zeroed
mOffset -= mPartialOffsets[idim];
|
||||
offset_diff += mPartialOffsets[idim];
|
||||
mPartialOffsets(idim) = 0;
|
||||
});
|
||||
});
|
||||
|
||||
return offset_diff;
|
||||
}
|
||||
|
||||
// private:
|
||||
Array<index_t, nDim> mIndex;
|
||||
Array<index_t, nOriginalDim> mOriginalIndex;
|
||||
@@ -312,18 +346,5 @@ struct MergedTensorCoordinate
|
||||
index_t mOffset;
|
||||
};
|
||||
|
||||
#if 0
|
||||
// implementation of MergedTensorCoordinate, when index_t is signed integer
|
||||
// mPartialOffsets is not needed, if index_t is signed integer type
|
||||
template<>
|
||||
struct TensorCoordinate<signed_t>
|
||||
{
|
||||
private:
|
||||
Array<_t, nDim> mIndex;
|
||||
Array<_t, nOriginalDim> mOriginalIndex;
|
||||
index_t mOffset;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -140,11 +140,23 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
|
||||
// Copy one slice of extent SliceLengths from p_src to p_dst, element by element.
// p_src / p_dst: base pointers of the source and destination buffers; element
// positions are obtained from the slice-origin coordinates via GetOffset().
// The first preprocessor branch (direct indexing from the member origins) is
// disabled; the second branch is the one that is compiled.
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
#if 0
|
||||
static_ford<SliceLengths>{}([&](auto data_id) {
|
||||
p_dst[(mDstSliceOrigin + data_id).GetOffset()] =
|
||||
p_src[(mSrcSliceOrigin + data_id).GetOffset()];
|
||||
|
||||
});
|
||||
#elif 1
|
||||
// work on local copies: Run() is const, and RepositionOrigin() modifies the
// coordinate it is called on
auto src_slice_origin = mSrcSliceOrigin;
|
||||
auto dst_slice_origin = mDstSliceOrigin;
|
||||
|
||||
// fold the offset removed from each origin into the base pointers
// NOTE(review): presumably done so the per-element offsets below are computed
// from a (partly) zeroed origin -- confirm the intended benefit
p_src += src_slice_origin.RepositionOrigin();
|
||||
p_dst += dst_slice_origin.RepositionOrigin();
|
||||
|
||||
// static_ford presumably invokes the lambda once per multi-index within
// SliceLengths -- confirm against its definition
static_ford<SliceLengths>{}([&](auto data_id) {
|
||||
p_dst[(dst_slice_origin + data_id).GetOffset()] =
|
||||
p_src[(src_slice_origin + data_id).GetOffset()];
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSlicingWindow(Array<index_t, nDim> step_sizes, bool positive_direction)
|
||||
|
||||
@@ -478,9 +478,9 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
#elif 1
|
||||
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
#endif
|
||||
<GridSize,
|
||||
|
||||
@@ -85,7 +85,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
@@ -520,7 +520,7 @@ int main(int argc, char* argv[])
|
||||
#if 0
|
||||
device_convolution_direct_v2_nchw_kcyx_nkhw
|
||||
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(
|
||||
in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
#elif 0
|
||||
|
||||
Reference in New Issue
Block a user