mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
nchw*cyxk*nkhw on AMD
This commit is contained in:
@@ -217,9 +217,9 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
#elif 1
|
||||
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
|
||||
#elif 1
|
||||
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
|
||||
#endif
|
||||
<GridSize,
|
||||
|
||||
@@ -57,6 +57,111 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 0
|
||||
// for 3x3, 34x34, v1r3, Pascal
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t NPerBlock = 2;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 16;
|
||||
|
||||
constexpr index_t NPerThread = 2;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
|
||||
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>;
|
||||
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 1;
|
||||
|
||||
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
|
||||
#elif 0
|
||||
// for 3x3, 34x34, v1r3, Vega 20
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t NPerBlock = 2;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 4;
|
||||
constexpr index_t WoPerBlock = 16;
|
||||
|
||||
constexpr index_t NPerThread = 2;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
|
||||
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 2, 16>;
|
||||
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 2;
|
||||
|
||||
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 4;
|
||||
#elif 1
|
||||
// for 3x3, 34x34, v1r3, Vega 20, try
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t NPerBlock = 4;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 4;
|
||||
constexpr index_t WoPerBlock = 8;
|
||||
|
||||
constexpr index_t NPerThread = 2;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
|
||||
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 4, 8>;
|
||||
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 1;
|
||||
|
||||
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 1;
|
||||
#elif 0
|
||||
// for 3x3, 28x28, v1r2, Pascal
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
@@ -90,76 +195,6 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
|
||||
using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
|
||||
#elif 0
|
||||
// for 3x3, 28x28, v1r3, Pascal, bad
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
|
||||
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
|
||||
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet
|
||||
|
||||
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
|
||||
#elif 1
|
||||
// for 3x3, 34x34, v1r3, Pascal
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t NPerBlock = 2;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 16;
|
||||
|
||||
constexpr index_t NPerThread = 2;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
|
||||
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>;
|
||||
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet
|
||||
|
||||
using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_W = 2;
|
||||
#endif
|
||||
|
||||
|
||||
@@ -608,7 +608,7 @@ int main(int argc, char* argv[])
|
||||
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
|
||||
|
||||
@@ -203,520 +203,520 @@ __device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, in
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:0\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 64)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:64\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 128)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:128\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 192)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:192\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 256)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:256\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 320)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:320\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 384)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:384\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 448)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:448\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 512)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:512\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 576)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:576\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 640)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:640\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 704)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:704\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 768)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:768\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 832)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:832\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 896)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:896\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 960)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:960\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1024)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1024\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1088)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1088\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1152)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1152\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1216)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1216\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1280)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1280\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1344)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1344\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1408)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1408\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1472)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1472\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1536)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1536\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1600)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1600\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1664)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1664\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1728)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1728\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1792)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1792\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1856)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1856\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1920)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1920\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1984)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1984\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2048)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2048\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2112)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2112\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2176)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2176\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2240)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2240\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2304)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2304\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2368)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2368\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2432)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2432\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2496)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2496\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2560)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2560\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2624)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2624\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2688)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2688\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2752)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2752\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2816)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2816\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2880)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2880\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2944)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2944\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3008)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3008\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3072)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3072\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3136)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3136\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3200)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3200\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3264)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3264\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3328)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3328\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3392)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3392\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3456)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3456\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3520)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3520\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3584)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3584\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3648)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3648\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3712)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3712\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3776)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3776\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3840)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3840\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3904)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3904\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3968)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3968\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 4032)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:4032\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 4096)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:4096\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -196,6 +196,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
|
||||
#if 0
|
||||
return blockwise_batch_gemm.Run(Xs...);
|
||||
#elif 0
|
||||
return blockwise_batch_gemm.Run_asm(Xs...);
|
||||
#else
|
||||
return blockwise_batch_gemm.Run_asm_v2(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space =
|
||||
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
@@ -293,7 +304,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
|
||||
@@ -322,7 +333,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
|
||||
p_wei_register_clipboard);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy_reorder.RunStoreRegisterClipboard(
|
||||
@@ -334,7 +345,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space,
|
||||
run_blockwise_batch_gemm(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
|
||||
@@ -78,22 +78,20 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
// constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
|
||||
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
|
||||
constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
|
||||
constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
|
||||
constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
|
||||
constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
|
||||
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
|
||||
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
|
||||
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
|
||||
const index_t w_block_work_id = itmp / NBlockWork;
|
||||
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
|
||||
constexpr auto block_work_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<NBlockWork, KBlockWork, HBlockWork, WBlockWork>{});
|
||||
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
|
||||
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
|
||||
|
||||
const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock;
|
||||
const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock;
|
||||
const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock;
|
||||
const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
@@ -193,6 +191,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// choose GEMM implementation here
|
||||
const auto run_blockwise_batch_gemm = [&](auto... Xs) {
|
||||
#if 1
|
||||
return blockwise_batch_gemm.Run(Xs...);
|
||||
#elif 0
|
||||
return blockwise_batch_gemm.Run_asm(Xs...);
|
||||
#else
|
||||
return blockwise_batch_gemm.Run_asm_v2(Xs...);
|
||||
#endif
|
||||
};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space =
|
||||
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
@@ -222,7 +231,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global +
|
||||
in_n_c_h_w_global_desc.Get1dIndex(
|
||||
@@ -267,7 +276,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
@@ -314,7 +323,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
|
||||
__syncthreads();
|
||||
|
||||
blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
|
||||
run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread);
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user