mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Update for recent MIOpen integration (#11)
* update for MIOpen integration
This commit is contained in:
@@ -49,7 +49,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
@@ -85,11 +84,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
"be violated");
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_howo_global_desc =
|
||||
unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);
|
||||
|
||||
constexpr auto out_k_b_global_desc =
|
||||
transform_tensor_descriptor(out_n_k_howo_global_desc,
|
||||
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
@@ -353,7 +353,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
|
||||
}
|
||||
|
||||
{
|
||||
#if 1 // debug
|
||||
#if 1 // debug
|
||||
// input: register to global memory, atomic add
|
||||
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
|
||||
? InMemoryDataOperation::none
|
||||
|
||||
@@ -81,11 +81,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
@@ -115,10 +115,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
PassThrough<C>{},
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
|
||||
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>>{},
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
|
||||
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
@@ -135,10 +135,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
PassThrough<K>{},
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
|
||||
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>>{},
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
|
||||
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
|
||||
@@ -110,11 +110,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
@@ -146,11 +146,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
|
||||
PassThrough<C>{},
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
|
||||
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{},
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
|
||||
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
@@ -168,11 +168,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
|
||||
PassThrough<K>{},
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
|
||||
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{},
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
|
||||
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
@@ -22,8 +22,6 @@ template <index_t GridSize,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t Iter_ytilda,
|
||||
index_t Iter_xtilda,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
@@ -47,9 +45,27 @@ template <index_t GridSize,
|
||||
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global) const
|
||||
__host__ __device__ static constexpr index_t GetNumberOfGemm()
|
||||
{
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
return Ytilda * Xtilda;
|
||||
}
|
||||
|
||||
template <index_t iYTilda, index_t iXTilda>
|
||||
__device__ static void RunImpl(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global)
|
||||
{
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
@@ -83,11 +99,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
@@ -119,11 +135,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
PassThrough<C>{},
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
|
||||
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{},
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
|
||||
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
@@ -141,11 +157,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
PassThrough<K>{},
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
|
||||
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{},
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
|
||||
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
@@ -215,8 +231,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
// GEMM
|
||||
constexpr index_t ytilda = Iter_ytilda;
|
||||
constexpr index_t xtilda = Iter_xtilda;
|
||||
constexpr index_t ytilda = iYTilda;
|
||||
constexpr index_t xtilda = iXTilda;
|
||||
|
||||
constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
|
||||
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
|
||||
@@ -327,6 +343,31 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
|
||||
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
|
||||
}
|
||||
|
||||
template <index_t GemmId>
|
||||
__device__ static void Run(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global)
|
||||
{
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
constexpr index_t iYTilda = GemmId / Xtilda;
|
||||
constexpr index_t iXTilda = GemmId % Xtilda;
|
||||
|
||||
static_assert(iYTilda < Ytilda && iXTilda < Xtilda, "wrong! iYtilda, iXtilda");
|
||||
|
||||
RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -49,7 +49,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
@@ -117,9 +116,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_k_b_global_desc =
|
||||
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
|
||||
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// GEMM
|
||||
|
||||
@@ -47,6 +47,9 @@ struct PassThrough
|
||||
}
|
||||
};
|
||||
|
||||
// By default, will automatically judge if is-valid check for upper-to-lower-index-mapping is
|
||||
// necessary
|
||||
// However, the check will be skipped if SkipIsValidCheck is set to true by user
|
||||
// LowerLengths: Sequence<...>
|
||||
template <typename LowerLengths,
|
||||
typename LeftPads,
|
||||
@@ -92,12 +95,12 @@ struct Pad
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
#if 1 // debug
|
||||
// skip valid check if user request it
|
||||
if(SkipIsValidCheck)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool flag = true;
|
||||
|
||||
for(index_t i = 0; i < nDim; ++i)
|
||||
@@ -384,6 +387,9 @@ struct UnMerge
|
||||
}
|
||||
};
|
||||
|
||||
// By default, will automatically judge if is-valid check for upper-to-lower-index-mapping is
|
||||
// necessary
|
||||
// However, the check will be skipped if SkipIsValidCheck is set to true by user
|
||||
// UpperLengths: Sequence<...>
|
||||
// Coefficients: Sequence<...>
|
||||
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
|
||||
@@ -442,12 +448,12 @@ struct Embed
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
#if 1 // debug
|
||||
// skip valid check if user request it
|
||||
if(SkipIsValidCheck)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool flag = true;
|
||||
|
||||
index_t ncorner = 1;
|
||||
|
||||
@@ -112,11 +112,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
transfer_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
|
||||
}
|
||||
}
|
||||
@@ -144,11 +144,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
transfer_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
|
||||
}
|
||||
}
|
||||
@@ -262,15 +262,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(p_src,
|
||||
src_nonlinear_coord.GetOffset() +
|
||||
src_linear_offset,
|
||||
p_src_long_vector,
|
||||
buffer_offset);
|
||||
transfer_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(p_src,
|
||||
src_nonlinear_coord.GetOffset() +
|
||||
src_linear_offset,
|
||||
p_src_long_vector,
|
||||
buffer_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,11 +301,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
transfer_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
|
||||
}
|
||||
}
|
||||
@@ -401,11 +401,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
transfer_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
|
||||
}
|
||||
}
|
||||
@@ -446,14 +446,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(p_dst_long_vector,
|
||||
buffer_offset,
|
||||
p_dst,
|
||||
dst_nonlinear_coord.GetOffset() + dst_linear_offset);
|
||||
transfer_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(p_dst_long_vector,
|
||||
buffer_offset,
|
||||
p_dst,
|
||||
dst_nonlinear_coord.GetOffset() +
|
||||
dst_linear_offset);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -8,19 +8,12 @@ namespace ck {
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
|
||||
{
|
||||
// disable inline asm due to the compiler issue: SWDEV-202749
|
||||
///\to-do: enable the inline asm after the compiler fix
|
||||
#if CK_WORKAROUND_SWDEV_202749
|
||||
c0 += a * b0;
|
||||
c1 += a * b1;
|
||||
#else
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %2, %3 \n \
|
||||
v_mac_f32 %1, %2, %4 \n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1)
|
||||
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
|
||||
#endif
|
||||
}
|
||||
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
|
||||
@@ -43,6 +43,10 @@
|
||||
#define CK_USE_AMD_XDLOPS_INLINE_ASM 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_USE_AMD_XDLOPS_EMULATE
|
||||
#define CK_USE_AMD_XDLOPS_EMULATE 0 // For internal debug purposes
|
||||
#endif
|
||||
|
||||
// experimental implementation
|
||||
#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1
|
||||
#define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0
|
||||
@@ -51,9 +55,6 @@
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
|
||||
|
||||
// workaround
|
||||
#define CK_WORKAROUND_SWDEV_202749 1
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum AddressSpace
|
||||
|
||||
@@ -70,7 +70,7 @@ template <typename T,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace,
|
||||
InMemoryDataOperation DstInMemOp>
|
||||
__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
{
|
||||
static_assert(DstInMemOp == InMemoryDataOperation::none ||
|
||||
DstInMemOp == InMemoryDataOperation::atomic_add,
|
||||
|
||||
@@ -38,7 +38,7 @@ template <typename T,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace,
|
||||
InMemoryDataOperation DstInMemOp>
|
||||
__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
{
|
||||
static_assert(DstInMemOp == InMemoryDataOperation::none ||
|
||||
DstInMemOp == InMemoryDataOperation::atomic_add,
|
||||
|
||||
@@ -103,9 +103,9 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
// highest common factor
|
||||
// greatest common divisor, aka highest common factor
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T hcf(T x, T y)
|
||||
__host__ __device__ constexpr T gcd(T x, T y)
|
||||
{
|
||||
if(x == 0)
|
||||
{
|
||||
@@ -124,30 +124,30 @@ __host__ __device__ constexpr T hcf(T x, T y)
|
||||
|
||||
if(x > y)
|
||||
{
|
||||
return hcf(x - y, y);
|
||||
return gcd(x - y, y);
|
||||
}
|
||||
|
||||
return hcf(x, y - x);
|
||||
return gcd(x, y - x);
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto hcf(Number<X>, Number<Y>)
|
||||
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
|
||||
{
|
||||
constexpr auto result = hcf(X, Y);
|
||||
constexpr auto result = gcd(X, Y);
|
||||
return Number<result>{};
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys>
|
||||
__host__ __device__ constexpr auto hcf(X x, Ys... ys)
|
||||
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
|
||||
{
|
||||
return hcf(x, ys...);
|
||||
return gcd(x, ys...);
|
||||
}
|
||||
|
||||
// least common multiple
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T lcm(T x, T y)
|
||||
{
|
||||
return (x * y) / hcf(x, y);
|
||||
return (x * y) / gcd(x, y);
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename... Zs>
|
||||
|
||||
@@ -152,11 +152,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
@@ -91,11 +91,11 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
@@ -2,13 +2,18 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename GridwiseOp, index_t GemmId, typename... Xs>
|
||||
__global__ void run_gridwise_convolution_backward_data_v4r1(Xs... xs)
|
||||
{
|
||||
GridwiseOp::template Run<GemmId>(xs...);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
@@ -119,11 +124,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
@@ -154,69 +159,61 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
KernelTimer timer;
|
||||
using GridwiseConv = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
|
||||
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
|
||||
constexpr index_t ytilda = decltype(ytilda_){};
|
||||
constexpr index_t xtilda = decltype(xtilda_){};
|
||||
static_for<0, GridwiseConv::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) {
|
||||
constexpr index_t gemm_id = decltype(gemm_id_){};
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ytilda,
|
||||
xtilda,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
});
|
||||
launch_kernel(run_gridwise_convolution_backward_data_v4r1<GridwiseConv,
|
||||
gemm_id,
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
});
|
||||
|
||||
timer.End();
|
||||
|
||||
float time = timer.GetElapsedTime();
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
|
||||
@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// BlockSize = 256, EperBlock = 8, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -127,7 +127,45 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
|
||||
#elif 1
|
||||
// BlockSize = 256, EPerBlock = 16, each thread hold 64 data
|
||||
// for 1x1
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<4, 1, 1, 2>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
|
||||
#elif 1
|
||||
// BlockSize = 64, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
@@ -84,7 +84,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// BlockSize = 256, GemmKPerBlock = 16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -117,7 +117,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 0
|
||||
// BlockSize = 256, GemmKPerBlock = 8
|
||||
// 1x1 filter, 8x8 image
|
||||
// for 1x1 filter, vector-read-b = 4
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
@@ -149,7 +149,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#elif 1
|
||||
// BlockSize = 256, GemmKPerBlock = 16
|
||||
// 1x1 filter, 8x8 image
|
||||
// for 1x1 filter, vector-read-b = 4
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
|
||||
@@ -161,10 +161,10 @@ int main(int argc, char* argv[])
|
||||
#elif 1
|
||||
// 1x7 filter, 0x3 pad, 17x17 input
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t C = 128;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 7;
|
||||
|
||||
@@ -246,28 +246,28 @@ int main(int argc, char* argv[])
|
||||
#endif
|
||||
}
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw
|
||||
#elif 1
|
||||
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
|
||||
#endif
|
||||
(in_nchw_desc,
|
||||
in_nchw_device,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
(in_nchw_desc,
|
||||
in_nchw_device,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
|
||||
@@ -29,13 +29,13 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// 1x1
|
||||
constexpr index_t N = 256;
|
||||
constexpr index_t C = 1024;
|
||||
constexpr index_t HI = 8;
|
||||
constexpr index_t WI = 8;
|
||||
constexpr index_t K = 1024;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user