mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
added strides and dilations suppport to implicit gemm v4
This commit is contained in:
@@ -22,6 +22,8 @@ template <index_t GridSize,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
@@ -117,15 +119,17 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
|
||||
|
||||
// input tensor
|
||||
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
|
||||
constexpr auto in_n0_n1_n2_h_w_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Ho>{})
|
||||
.Slice(I3, Number<Wo>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{})
|
||||
.Extract(Sequence<0, 1, 2, 4, 5>{});
|
||||
constexpr auto in_n0_n1_n2_h_w_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrides::Get(I0)>{})
|
||||
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrides::Get(I1)>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{})
|
||||
.Extract(Sequence<0, 1, 2, 4, 5>{});
|
||||
|
||||
// batch descritpor for device memory
|
||||
constexpr auto in_c_y_x_global_desc = in_n_c_h_w_global_desc.Slice(I2, Number<Y>{})
|
||||
.Slice(I3, Number<X>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
constexpr auto in_c_y_x_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilations::Get(I0)>{})
|
||||
.StridedSlice(I3, Number<X>{}, Number<ConvDilations::Get(I1)>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
|
||||
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
|
||||
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
|
||||
@@ -320,6 +320,18 @@ struct ConstantTensorDescriptor
|
||||
return ConstantTensorDescriptor<slice_lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t SliceLength, index_t SliceStride>
|
||||
__host__ __device__ static constexpr auto
|
||||
StridedSlice(Number<IDim>, Number<SliceLength>, Number<SliceStride>)
|
||||
{
|
||||
constexpr index_t new_stride = Strides::Get(Number<IDim>{}) * SliceStride;
|
||||
|
||||
using new_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLength>{}));
|
||||
using new_strides = decltype(Strides::Modify(Number<IDim>{}, Number<new_stride>{}));
|
||||
|
||||
return ConstantTensorDescriptor<new_lengths, new_strides>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t... FoldIntervals>
|
||||
__host__ __device__ static constexpr auto Fold(Number<IDim>, Number<FoldIntervals>...)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user