mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
Update for recent MIOpen integration (#11)
* update for MIOpen integration
This commit is contained in:
@@ -49,7 +49,6 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
@@ -85,11 +84,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
|
||||
"be violated");
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_n_k_howo_global_desc =
|
||||
unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);
|
||||
|
||||
constexpr auto out_k_b_global_desc =
|
||||
transform_tensor_descriptor(out_n_k_howo_global_desc,
|
||||
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
@@ -353,7 +353,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl
|
||||
}
|
||||
|
||||
{
|
||||
#if 1 // debug
|
||||
#if 1 // debug
|
||||
// input: register to global memory, atomic add
|
||||
constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW)
|
||||
? InMemoryDataOperation::none
|
||||
|
||||
@@ -81,11 +81,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
@@ -115,10 +115,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
PassThrough<C>{},
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
|
||||
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>>{},
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
|
||||
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
@@ -135,10 +135,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
|
||||
PassThrough<K>{},
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
|
||||
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>>{},
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
|
||||
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
|
||||
@@ -110,11 +110,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
@@ -146,11 +146,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
|
||||
PassThrough<C>{},
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
|
||||
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{},
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
|
||||
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
@@ -168,11 +168,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw
|
||||
PassThrough<K>{},
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
|
||||
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{},
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
|
||||
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
@@ -22,8 +22,6 @@ template <index_t GridSize,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads,
|
||||
index_t Iter_ytilda,
|
||||
index_t Iter_xtilda,
|
||||
index_t GemmMPerBlock,
|
||||
index_t GemmNPerBlock,
|
||||
index_t GemmKPerBlock,
|
||||
@@ -47,9 +45,27 @@ template <index_t GridSize,
|
||||
index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
|
||||
struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global) const
|
||||
__host__ __device__ static constexpr index_t GetNumberOfGemm()
|
||||
{
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
return Ytilda * Xtilda;
|
||||
}
|
||||
|
||||
template <index_t iYTilda, index_t iXTilda>
|
||||
__device__ static void RunImpl(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global)
|
||||
{
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
@@ -83,11 +99,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
"be violated");
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
@@ -119,11 +135,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
PassThrough<C>{},
|
||||
Embed<Y,
|
||||
Sequence<Ydot, Ytilda>,
|
||||
Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>,
|
||||
Sequence<ConvStrideH / gcd_stride_dilation_h, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{},
|
||||
Embed<X,
|
||||
Sequence<Xdot, Xtilda>,
|
||||
Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>,
|
||||
Sequence<ConvStrideW / gcd_stride_dilation_w, 1, 0>,
|
||||
wei_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
@@ -141,11 +157,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
PassThrough<K>{},
|
||||
Embed<Ho,
|
||||
Sequence<Ydot, Htilda>,
|
||||
Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>,
|
||||
Sequence<-ConvDilationH / gcd_stride_dilation_h, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{},
|
||||
Embed<Wo,
|
||||
Sequence<Xdot, Wtilda>,
|
||||
Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>,
|
||||
Sequence<-ConvDilationW / gcd_stride_dilation_w, 1, 0>,
|
||||
out_skip_all_out_of_bound_check>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
@@ -215,8 +231,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}));
|
||||
|
||||
// GEMM
|
||||
constexpr index_t ytilda = Iter_ytilda;
|
||||
constexpr index_t xtilda = Iter_xtilda;
|
||||
constexpr index_t ytilda = iYTilda;
|
||||
constexpr index_t xtilda = iXTilda;
|
||||
|
||||
constexpr index_t YdotNonZero = (ytilda + 1) * Ydot <= Y ? Ydot : Y % Ydot;
|
||||
constexpr index_t XdotNonZero = (xtilda + 1) * Xdot <= X ? Xdot : X % Xdot;
|
||||
@@ -327,6 +343,31 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
|
||||
gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global);
|
||||
}
|
||||
|
||||
template <index_t GemmId>
|
||||
__device__ static void Run(Float* __restrict__ p_in_global,
|
||||
const Float* __restrict__ p_wei_global,
|
||||
const Float* __restrict__ p_out_global)
|
||||
{
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr index_t gcd_stride_dilation_h = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t gcd_stride_dilation_w = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / gcd_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / gcd_stride_dilation_w;
|
||||
|
||||
constexpr index_t iYTilda = GemmId / Xtilda;
|
||||
constexpr index_t iXTilda = GemmId % Xtilda;
|
||||
|
||||
static_assert(iYTilda < Ytilda && iXTilda < Xtilda, "wrong! iYtilda, iXtilda");
|
||||
|
||||
RunImpl<iYTilda, iXTilda>(p_in_global, p_wei_global, p_out_global);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -49,7 +49,6 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
@@ -117,9 +116,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
|
||||
// output tensor
|
||||
constexpr auto out_k_b_global_desc =
|
||||
transform_tensor_descriptor(out_n_k_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
|
||||
transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3),
|
||||
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// GEMM
|
||||
|
||||
@@ -47,6 +47,9 @@ struct PassThrough
|
||||
}
|
||||
};
|
||||
|
||||
// By default, will automatically judge if is-valid check for upper-to-lower-index-mapping is
|
||||
// necessary
|
||||
// However, the check will be skipped if SkipIsValidCheck is set to true by user
|
||||
// LowerLengths: Sequence<...>
|
||||
template <typename LowerLengths,
|
||||
typename LeftPads,
|
||||
@@ -92,12 +95,12 @@ struct Pad
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
#if 1 // debug
|
||||
// skip valid check if user request it
|
||||
if(SkipIsValidCheck)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool flag = true;
|
||||
|
||||
for(index_t i = 0; i < nDim; ++i)
|
||||
@@ -384,6 +387,9 @@ struct UnMerge
|
||||
}
|
||||
};
|
||||
|
||||
// By default, will automatically judge if is-valid check for upper-to-lower-index-mapping is
|
||||
// necessary
|
||||
// However, the check will be skipped if SkipIsValidCheck is set to true by user
|
||||
// UpperLengths: Sequence<...>
|
||||
// Coefficients: Sequence<...>
|
||||
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp]
|
||||
@@ -442,12 +448,12 @@ struct Embed
|
||||
|
||||
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
|
||||
{
|
||||
#if 1 // debug
|
||||
// skip valid check if user request it
|
||||
if(SkipIsValidCheck)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
|
||||
bool flag = true;
|
||||
|
||||
index_t ncorner = 1;
|
||||
|
||||
@@ -112,11 +112,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
transfer_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
|
||||
}
|
||||
}
|
||||
@@ -144,11 +144,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
transfer_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
|
||||
}
|
||||
}
|
||||
@@ -262,15 +262,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(p_src,
|
||||
src_nonlinear_coord.GetOffset() +
|
||||
src_linear_offset,
|
||||
p_src_long_vector,
|
||||
buffer_offset);
|
||||
transfer_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(p_src,
|
||||
src_nonlinear_coord.GetOffset() +
|
||||
src_linear_offset,
|
||||
p_src_long_vector,
|
||||
buffer_offset);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -301,11 +301,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
transfer_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(
|
||||
p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset());
|
||||
}
|
||||
}
|
||||
@@ -401,11 +401,11 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(src_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
transfer_data<SrcData,
|
||||
SrcDataPerRead,
|
||||
SrcAddressSpace,
|
||||
AddressSpace::vgpr,
|
||||
InMemoryDataOperation::none>(
|
||||
p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset);
|
||||
}
|
||||
}
|
||||
@@ -446,14 +446,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// has the valid/invalid mapping situation
|
||||
if(dst_coord.IsOffsetValidAssumingUpperIndexIsValid())
|
||||
{
|
||||
move_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(p_dst_long_vector,
|
||||
buffer_offset,
|
||||
p_dst,
|
||||
dst_nonlinear_coord.GetOffset() + dst_linear_offset);
|
||||
transfer_data<DstData,
|
||||
DstDataPerWrite,
|
||||
AddressSpace::vgpr,
|
||||
DstAddressSpace,
|
||||
DstInMemOp>(p_dst_long_vector,
|
||||
buffer_offset,
|
||||
p_dst,
|
||||
dst_nonlinear_coord.GetOffset() +
|
||||
dst_linear_offset);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -8,19 +8,12 @@ namespace ck {
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
|
||||
{
|
||||
// disable inline asm due to the compiler issue: SWDEV-202749
|
||||
///\to-do: enable the inline asm after the compiler fix
|
||||
#if CK_WORKAROUND_SWDEV_202749
|
||||
c0 += a * b0;
|
||||
c1 += a * b1;
|
||||
#else
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %2, %3 \n \
|
||||
v_mac_f32 %1, %2, %4 \n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1)
|
||||
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
|
||||
#endif
|
||||
}
|
||||
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
|
||||
@@ -43,6 +43,10 @@
|
||||
#define CK_USE_AMD_XDLOPS_INLINE_ASM 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_USE_AMD_XDLOPS_EMULATE
|
||||
#define CK_USE_AMD_XDLOPS_EMULATE 0 // For internal debug purposes
|
||||
#endif
|
||||
|
||||
// experimental implementation
|
||||
#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1
|
||||
#define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0
|
||||
@@ -51,9 +55,6 @@
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
|
||||
|
||||
// workaround
|
||||
#define CK_WORKAROUND_SWDEV_202749 1
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum AddressSpace
|
||||
|
||||
@@ -70,7 +70,7 @@ template <typename T,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace,
|
||||
InMemoryDataOperation DstInMemOp>
|
||||
__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
{
|
||||
static_assert(DstInMemOp == InMemoryDataOperation::none ||
|
||||
DstInMemOp == InMemoryDataOperation::atomic_add,
|
||||
|
||||
@@ -38,7 +38,7 @@ template <typename T,
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace,
|
||||
InMemoryDataOperation DstInMemOp>
|
||||
__device__ void move_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset)
|
||||
{
|
||||
static_assert(DstInMemOp == InMemoryDataOperation::none ||
|
||||
DstInMemOp == InMemoryDataOperation::atomic_add,
|
||||
|
||||
@@ -103,9 +103,9 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
// highest common factor
|
||||
// greatest common divisor, aka highest common factor
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T hcf(T x, T y)
|
||||
__host__ __device__ constexpr T gcd(T x, T y)
|
||||
{
|
||||
if(x == 0)
|
||||
{
|
||||
@@ -124,30 +124,30 @@ __host__ __device__ constexpr T hcf(T x, T y)
|
||||
|
||||
if(x > y)
|
||||
{
|
||||
return hcf(x - y, y);
|
||||
return gcd(x - y, y);
|
||||
}
|
||||
|
||||
return hcf(x, y - x);
|
||||
return gcd(x, y - x);
|
||||
}
|
||||
|
||||
template <index_t X, index_t Y>
|
||||
__host__ __device__ constexpr auto hcf(Number<X>, Number<Y>)
|
||||
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
|
||||
{
|
||||
constexpr auto result = hcf(X, Y);
|
||||
constexpr auto result = gcd(X, Y);
|
||||
return Number<result>{};
|
||||
}
|
||||
|
||||
template <typename X, typename... Ys>
|
||||
__host__ __device__ constexpr auto hcf(X x, Ys... ys)
|
||||
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
|
||||
{
|
||||
return hcf(x, ys...);
|
||||
return gcd(x, ys...);
|
||||
}
|
||||
|
||||
// least common multiple
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T lcm(T x, T y)
|
||||
{
|
||||
return (x * y) / hcf(x, y);
|
||||
return (x * y) / gcd(x, y);
|
||||
}
|
||||
|
||||
template <typename X, typename Y, typename... Zs>
|
||||
|
||||
Reference in New Issue
Block a user