mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
@@ -335,15 +335,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
#if 0
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
// blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{},
|
||||
// True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
|
||||
#else
|
||||
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
#endif
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -367,14 +360,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
#if 0
|
||||
blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
// blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number<EPerBlock>{}, True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
|
||||
#else
|
||||
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
#endif
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -447,7 +434,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex(
|
||||
k_thread_data_on_global, 0, b_thread_data_on_global, 0);
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2<
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_h_w_thread_desc),
|
||||
decltype(out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc),
|
||||
@@ -469,8 +456,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
7,
|
||||
7,
|
||||
1,
|
||||
1>(
|
||||
{0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
|
||||
1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -244,6 +244,8 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);
|
||||
|
||||
const Float* p_wei_block_on_global = p_wei_global;
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global, p_in_block_double);
|
||||
@@ -273,13 +275,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
|
||||
|
||||
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global,
|
||||
p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
@@ -297,13 +300,13 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
|
||||
// even iteration
|
||||
blockwise_in_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSlicingWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStrides()[0];
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global, p_wei_register_buffer);
|
||||
blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
@@ -237,7 +237,10 @@ struct MergedTensorCoordinate
|
||||
index_t normal_offset_diff = 0;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
|
||||
if(step_sizes[idim] != 0)
|
||||
{
|
||||
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, true>{});
|
||||
}
|
||||
});
|
||||
|
||||
return *this;
|
||||
@@ -249,7 +252,10 @@ struct MergedTensorCoordinate
|
||||
static_assert(is_same<typename T::data_type, index_t>{} && T::GetSize() == nDim, "wrong!");
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
|
||||
if(step_sizes[idim] != 0)
|
||||
{
|
||||
this->MoveOnDimension(idim, step_sizes[idim], integral_constant<bool, false>{});
|
||||
}
|
||||
});
|
||||
|
||||
return *this;
|
||||
|
||||
@@ -402,6 +402,19 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template <class T, bool PositiveDirection>
|
||||
__device__ void
|
||||
MoveSrcSlicingWindow(T step_sizes,
|
||||
integral_constant<bool, PositiveDirection> positive_direction)
|
||||
{
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
if(step_sizes[idim] != 0)
|
||||
{
|
||||
MoveSlicingWindowOnSourceTensor(idim, step_sizes[idim], positive_direction);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t BlockSize,
|
||||
@@ -502,21 +515,6 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
private:
|
||||
using RegisterBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
|
||||
|
||||
#if 0
|
||||
using ThreadwiseLoad =
|
||||
ThreadwiseGenericTensorSliceCopy_v2<SrcDesc,
|
||||
RegisterBufferDesc,
|
||||
SrcCoordinate,
|
||||
NormalTensorCoordinate<RegisterBufferDesc>,
|
||||
SubLengths>;
|
||||
|
||||
using ThreadwiseStore =
|
||||
ThreadwiseGenericTensorSliceCopy_v2<RegisterBufferDesc,
|
||||
DstDesc,
|
||||
NormalTensorCoordinate<RegisterBufferDesc>,
|
||||
DstCoordinate,
|
||||
SubLengths>;
|
||||
#else
|
||||
using ThreadwiseLoad =
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<SrcDesc,
|
||||
RegisterBufferDesc,
|
||||
@@ -542,7 +540,7 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
#endif
|
||||
|
||||
ThreadwiseLoad mThreadwiseLoad;
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
};
|
||||
|
||||
@@ -594,7 +594,6 @@ struct ThreadwiseGenericTensorSliceCopy_v2
|
||||
DstCoordinate mDstSliceOrigin;
|
||||
};
|
||||
|
||||
#if 1
|
||||
// This threadwise copy allows vector access of src and dst.
|
||||
// It allows the dimensions of vector access to be different on src and dst.
|
||||
// It also allows the vector size to be different on src and dst.
|
||||
@@ -623,6 +622,49 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
DstCoordinate dst_slice_origin)
|
||||
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
|
||||
nDim == SrcDimAccessOrder::GetSize() &&
|
||||
nDim == DstDimAccessOrder::GetSize(),
|
||||
"wrong! # of dimensions not the same");
|
||||
|
||||
static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
|
||||
is_valid_sequence_map<DstDimAccessOrder>::value,
|
||||
"wrong! map is not valid");
|
||||
|
||||
static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
|
||||
SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
|
||||
"wrong! cannot evenly divide");
|
||||
|
||||
// check vectorized memory access
|
||||
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
|
||||
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
|
||||
|
||||
static_if<!SrcDesc::ContainMultipleOriginalDimensions(src_vector_access_dim)>{}(
|
||||
[&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(SrcDesc{}).GetStride(src_vector_access_dim) == 1 || SrcDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
})
|
||||
.Else([&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(SrcDesc{}).GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
|
||||
SrcDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
});
|
||||
|
||||
static_if<!DstDesc::ContainMultipleOriginalDimensions(dst_vector_access_dim)>{}(
|
||||
[&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(DstDesc{}).GetStride(dst_vector_access_dim) == 1 || DstDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
})
|
||||
.Else([&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(DstDesc{}).GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
|
||||
DstDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
});
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1()
|
||||
@@ -725,9 +767,6 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
constexpr index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex(
|
||||
src_merged_dim_data_id + src_normal_dim_data_id + scalar_id);
|
||||
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc.GetOffsetFromMultiIndex(src_data_begin_id + scalar_id);
|
||||
|
||||
p_buffer[buffer_offset] = reinterpret_cast<const TData*>(&vector_data)[i];
|
||||
});
|
||||
});
|
||||
@@ -900,7 +939,6 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
SrcCoordinate mSrcSliceOrigin;
|
||||
DstCoordinate mDstSliceOrigin;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user