mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
fix v1r3 output reorder bug
[ROCm/composable_kernel commit: 63cdc6d2a4]
This commit is contained in:
@@ -87,7 +87,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 1;
|
||||
|
||||
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
|
||||
using WeiBlockCopyClusterLengths = void;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
|
||||
@@ -122,7 +122,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 2;
|
||||
|
||||
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
|
||||
using WeiBlockCopyClusterLengths = void;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 4;
|
||||
@@ -136,10 +136,10 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
|
||||
constexpr index_t HoPerBlock = 4;
|
||||
constexpr index_t WoPerBlock = 8;
|
||||
|
||||
constexpr index_t NPerThread = 2;
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 4;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
@@ -155,14 +155,14 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
|
||||
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 4, 8>;
|
||||
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 1;
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 4;
|
||||
|
||||
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
|
||||
using WeiBlockCopyClusterLengths = void;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
|
||||
#elif 0
|
||||
// for 3x3, 28x28, v1r2, Pascal
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
|
||||
#elif 1
|
||||
// for 3x3, 28x28, v1r3, Pascal
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t NPerBlock = 16;
|
||||
@@ -186,13 +186,13 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>;
|
||||
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
|
||||
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
|
||||
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 2;
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 4;
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 4;
|
||||
|
||||
using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>;
|
||||
using WeiBlockCopyClusterLengths = void;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
|
||||
|
||||
@@ -371,7 +371,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
std::size_t ho = HoPerTile * htile + j;
|
||||
for(int i = 0; i < WoPerTile; ++i)
|
||||
{
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
|
||||
}
|
||||
}
|
||||
@@ -413,13 +413,13 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
#if 1
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 34;
|
||||
constexpr index_t WI = 34;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
@@ -597,6 +597,8 @@ int main(int argc, char* argv[])
|
||||
};
|
||||
wei_kcyx.GenerateTensorValue(gen_wei, num_thread);
|
||||
#endif
|
||||
|
||||
// out_nkhw_device.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
}
|
||||
|
||||
#if 1
|
||||
|
||||
Reference in New Issue
Block a user