mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
adding implicit gemm v3
This commit is contained in:
@@ -47,8 +47,8 @@ void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 1;
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 1;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#endif
|
||||
|
||||
@@ -92,7 +92,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
|
||||
@@ -85,6 +85,9 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_C_K = Sequence<1, 4>;
|
||||
using WeiBlockCopyClusterLengths_C_K = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerAccess_K = 4;
|
||||
#endif
|
||||
|
||||
@@ -123,8 +126,11 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
|
||||
InBlockCopyClusterLengths_N1_N2_C_B,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopySubLengths_C_K,
|
||||
WeiBlockCopyClusterLengths_C_K,
|
||||
WeiBlockCopyDataPerAccess_K>{};
|
||||
|
||||
#if 1
|
||||
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
@@ -138,6 +144,7 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
|
||||
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
#endif
|
||||
}
|
||||
|
||||
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
|
||||
|
||||
@@ -411,7 +411,18 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if 1
|
||||
#if 0
|
||||
constexpr index_t N = 8;
|
||||
constexpr index_t C = 8;
|
||||
constexpr index_t HI = 3;
|
||||
constexpr index_t WI = 18;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 1
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
@@ -635,11 +646,13 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
#if 1
|
||||
if(Y == 3 && X == 3)
|
||||
{
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user