From 08cbac98cc4b9b941362a8e1812ef9986e9e912f Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 30 Jul 2019 18:20:55 -0500 Subject: [PATCH] added (1x4)x(2x4) threadwise gemm --- .../tensor_operation/blockwise_gemm.hpp | 311 +++++++++--------- .../include/utility/amd_inline_asm.hpp | 9 + ...tion_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp | 74 +---- driver/src/driver.cpp | 4 +- 4 files changed, 166 insertions(+), 232 deletions(-) diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm.hpp index 9e3bc0f1c3..819ecf0c41 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm.hpp @@ -12,17 +12,18 @@ namespace ck { // if following number are power of 2, index calculation shall be greatly reduced: -// MPerThreadSubC, NPerThreadSubC, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster +// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster, +// MLevel1ThreadCluster, NLevel1ThreadCluster template @@ -39,8 +40,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2() { - constexpr index_t ThreadPerLevel1Cluster = - MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; + constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster * + MLevel1ThreadCluster * NLevel1ThreadCluster; static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n"); @@ -50,8 +51,8 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr index_t M = BlockMatrixA::NCol(); // A is transposed constexpr index_t N = BlockMatrixB::NCol(); - static_assert(M % (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster) == 0 && - N % (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster) == 0, + static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 && + N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0, "wrong! Cannot evenly divide work among\n"); static_assert( @@ -69,26 +70,28 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr index_t M = BlockMatrixA::NCol(); // A is transposed constexpr index_t N = BlockMatrixB::NCol(); - constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0Cluster * MLevel1Cluster); - constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0Cluster * NLevel1Cluster); + constexpr index_t MRepeat = + M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster); + constexpr index_t NRepeat = + N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster); return Sequence{}; } __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) { - constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; + constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster; index_t level1_id = thread_id / ThreadPerLevel0Cluster; - index_t level1_m_id = level1_id / NLevel1Cluster; - index_t level1_n_id = level1_id % NLevel1Cluster; + index_t level1_m_id = level1_id / NLevel1ThreadCluster; + index_t level1_n_id = level1_id % NLevel1ThreadCluster; index_t level0_id = thread_id % ThreadPerLevel0Cluster; - index_t level0_m_id = level0_id / NLevel0Cluster; - index_t level0_n_id = level0_id % NLevel0Cluster; + index_t level0_m_id = level0_id / NLevel0ThreadCluster; + index_t level0_n_id = level0_id % NLevel0ThreadCluster; - constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; - constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; + constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster; + constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster; return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; @@ -99,8 +102,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 { constexpr auto c_thread_mtx = ThreadMatrixC{}; - constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t MPerLevel1Cluster = + MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; + constexpr index_t NPerLevel1Cluster = + NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster; index_t m_repeat = m_in_c / MPerThreadSubC; index_t n_repeat = n_in_c / NPerThreadSubC; @@ -139,8 +144,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t MPerLevel1Cluster = + MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; + constexpr index_t NPerLevel1Cluster = + NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster; // assertion for inline asm static_assert(is_same{} && is_same{} && @@ -184,6 +191,123 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); } + + __device__ void Run_amd_asm_v2(const float* __restrict__ p_a_block, + const float* __restrict__ p_b_block, + float* __restrict__ p_c_thread) const + { + constexpr auto a_block_mtx = BlockMatrixA{}; + constexpr auto b_block_mtx = BlockMatrixB{}; + constexpr auto c_thread_mtx = ThreadMatrixC{}; + + constexpr index_t M = a_block_mtx.NCol(); + constexpr index_t N = b_block_mtx.NCol(); + constexpr index_t K = a_block_mtx.NRow(); + + constexpr index_t MPerThread = c_thread_mtx.NRow(); + constexpr index_t NPerThread = c_thread_mtx.NCol(); + + // thread A, B for GEMM + constexpr auto a_thread_mtx = + make_ConstantMatrixDescriptor_packed(Number{}, Number{}); + + constexpr auto b_thread_mtx = + make_ConstantMatrixDescriptor_packed(Number{}, Number{}); + + float p_a_thread[a_thread_mtx.GetElementSpace()]; + float p_b_thread[b_thread_mtx.GetElementSpace()]; + + constexpr index_t MThreadCluster = MLevel0ThreadCluster * MLevel1ThreadCluster; + constexpr index_t NThreadCluster = NLevel0ThreadCluster * NLevel1ThreadCluster; + + constexpr index_t MDataCluster = M / MPerThreadSubC; + constexpr index_t NDataCluster = N / NPerThreadSubC; + + constexpr index_t MRepeat = MDataCluster / MThreadCluster; + constexpr index_t NRepeat = NDataCluster / NThreadCluster; + + // assertion for inline asm + static_assert((MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 && + KPerThreadLoop == 1) || + (MPerThreadSubC == 2 && NPerThreadSubC == 4 && MRepeat == 2 && + NRepeat == 2 && KPerThreadLoop == 1), + "Run_amd_asm cannot deal with this GEMM shape yet"); + + static_assert(DataPerReadA == MPerThreadSubC && DataPerReadB == NPerThreadSubC, + "wrong! Run_amd_asm doesn't support this config"); + + if(MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 && + KPerThreadLoop == 1) + { + using float4_type = vector_type::MemoryType; + + float4_type* reg_a = reinterpret_cast(p_a_thread); + float4_type* reg_b = reinterpret_cast(p_b_thread); + float4_type* reg_c = reinterpret_cast(p_c_thread); + + const float4_type* p_a = + reinterpret_cast(&p_a_block[mMyThreadOffsetA]); + const float4_type* p_b = + reinterpret_cast(&p_b_block[mMyThreadOffsetB]); + + reg_a[0] = p_a[0]; + reg_b[0] = p_b[0]; + reg_b[1] = p_b[NThreadCluster]; + reg_a[1] = p_a[MThreadCluster]; + outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); + outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); +#pragma unroll + for(index_t k = 1; k < K; ++k) + { + reg_a[0] = p_a[k * MDataCluster]; + outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); + reg_b[0] = p_b[k * NDataCluster]; + outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); + reg_b[1] = p_b[k * NDataCluster + NThreadCluster]; + reg_a[1] = p_a[k * MDataCluster + MThreadCluster]; + outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); + outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); + } + outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); + outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); + } + else if(MPerThreadSubC == 2 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 && + KPerThreadLoop == 1) + { + using float2_type = vector_type::MemoryType; + using float4_type = vector_type::MemoryType; + + float2_type* reg_a = reinterpret_cast(p_a_thread); + float4_type* reg_b = reinterpret_cast(p_b_thread); + float4_type* reg_c = reinterpret_cast(p_c_thread); + + const float2_type* p_a = + reinterpret_cast(&p_a_block[mMyThreadOffsetA]); + const float4_type* p_b = + reinterpret_cast(&p_b_block[mMyThreadOffsetB]); + + reg_a[0] = p_a[0]; + reg_b[0] = p_b[0]; + reg_b[1] = p_b[NThreadCluster]; + reg_a[1] = p_a[MThreadCluster]; + outerProduct2x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2]); + outerProduct2x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3]); +#pragma unroll + for(index_t k = 1; k < K; ++k) + { + reg_a[0] = p_a[k * MDataCluster]; + outerProduct2x4(reg_a[1], reg_b[0], reg_c[4], reg_c[6]); + reg_b[0] = p_b[k * NDataCluster]; + outerProduct2x4(reg_a[1], reg_b[1], reg_c[5], reg_c[7]); + reg_b[1] = p_b[k * NDataCluster + NThreadCluster]; + reg_a[1] = p_a[k * MDataCluster + MThreadCluster]; + outerProduct2x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2]); + outerProduct2x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3]); + } + outerProduct2x4(reg_a[1], reg_b[0], reg_c[4], reg_c[6]); + outerProduct2x4(reg_a[1], reg_b[1], reg_c[5], reg_c[7]); + } + } #endif template @@ -220,8 +344,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; + constexpr index_t MPerLevel1Cluster = + MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; + constexpr index_t NPerLevel1Cluster = + NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster; constexpr index_t MRepeat = MPerThread / MPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC; @@ -273,141 +399,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 } } - template - __device__ void RunRegisterDoubleBuffer_source(FloatA* const p_a_block, - FloatB* const p_b_block, - FloatC* p_c_thread) const - { - constexpr auto True = integral_constant{}; - constexpr auto False = integral_constant{}; - - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr index_t K = a_block_mtx.NRow(); - - constexpr index_t MPerThread = c_thread_mtx.NRow(); - constexpr index_t NPerThread = c_thread_mtx.NCol(); - - // thread A, B for GEMM - constexpr auto a_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - constexpr auto b_thread_mtx = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - // thread A-sub, B-sub for copy - constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - // register - FloatA p_a_thread_0[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread_0[b_thread_mtx.GetElementSpace()]; - - FloatA p_a_thread_1[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread_1[b_thread_mtx.GetElementSpace()]; - - constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - -// preload A, B -#pragma unroll - for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { // copy A-sub to form A - threadwise_matrix_copy(a_block_mtx, - p_a_block + mMyThreadOffsetA + m_repeat * MPerLevel1Cluster, - a_thread_sub_mtx, - p_a_thread_0 + m_repeat * MPerThreadSubC, - a_thread_sub_mtx.GetLengths(), - Number{}); - } - -#pragma unroll - for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { // copy B-sub to form B - threadwise_matrix_copy(b_block_mtx, - p_b_block + mMyThreadOffsetB + n_repeat * NPerLevel1Cluster, - b_thread_sub_mtx, - p_b_thread_0 + n_repeat * NPerThreadSubC, - b_thread_sub_mtx.GetLengths(), - Number{}); - } - - bool even_loop = true; - -#pragma unroll - for(index_t k_begin = 0; k_begin + KPerThreadLoop < K; - k_begin += KPerThreadLoop, even_loop = !even_loop) - { // loop over k - FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1; - FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1; - - FloatA* p_a_thread_next = even_loop ? p_a_thread_1 : p_a_thread_0; - FloatB* p_b_thread_next = even_loop ? p_b_thread_1 : p_b_thread_0; - -// preload next A, B -#pragma unroll - for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { // copy A-sub to form A - threadwise_matrix_copy(a_block_mtx, - p_a_block + mMyThreadOffsetA + - (k_begin + 1) * a_block_mtx.RowStride() + - m_repeat * MPerLevel1Cluster, - a_thread_sub_mtx, - p_a_thread_next + m_repeat * MPerThreadSubC, - a_thread_sub_mtx.GetLengths(), - Number{}); - } - -#pragma unroll - for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { // copy B-sub to form B - threadwise_matrix_copy(b_block_mtx, - p_b_block + mMyThreadOffsetB + - (k_begin + 1) * b_block_mtx.RowStride() + - n_repeat * NPerLevel1Cluster, - b_thread_sub_mtx, - p_b_thread_next + n_repeat * NPerThreadSubC, - b_thread_sub_mtx.GetLengths(), - Number{}); - } - - // C = A * B - threadwise_gemm(a_thread_mtx, - True, - p_a_thread_now, - b_thread_mtx, - False, - p_b_thread_now, - c_thread_mtx, - False, - p_c_thread); - } - - // last loop - { - FloatA* p_a_thread_now = even_loop ? p_a_thread_0 : p_a_thread_1; - FloatB* p_b_thread_now = even_loop ? p_b_thread_0 : p_b_thread_1; - - // C = A * B - threadwise_gemm(a_thread_mtx, - True, - p_a_thread_now, - b_thread_mtx, - False, - p_b_thread_now, - c_thread_mtx, - False, - p_c_thread); - } - } template __device__ void Run(const FloatA* __restrict__ p_a_block, const FloatB* __restrict__ p_b_block, @@ -415,7 +406,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 { #if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM - Run_amd_asm(p_a_block, p_b_block, p_c_thread); + Run_amd_asm_v2(p_a_block, p_b_block, p_c_thread); #else Run_source(p_a_block, p_b_block, p_c_thread); #endif diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp index e7e20808e1..0a17b4bd3a 100644 --- a/composable_kernel/include/utility/amd_inline_asm.hpp +++ b/composable_kernel/include/utility/amd_inline_asm.hpp @@ -105,6 +105,15 @@ __device__ void outerProduct1x4(const float& a, outerProduct1x4(&a, reinterpret_cast(&b), reinterpret_cast(&c)); } +__device__ void outerProduct2x4(const vector_type::MemoryType& a, + const vector_type::MemoryType& b, + vector_type::MemoryType& c0, + vector_type::MemoryType& c1) +{ + outerProduct1x4(a.x, b, c0); + outerProduct1x4(a.y, b, c1); +} + __device__ void outerProduct4x4(const vector_type::MemoryType& a, const vector_type::MemoryType& b, vector_type::MemoryType& c0, diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 6b7e1c4451..eb132cd331 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -60,6 +60,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t B = (N * Ho * Wo) / (N1 * N2); #if 0 + // each thread hold 64 data constexpr index_t BlockSize = 256; constexpr index_t BPerBlock = 16; @@ -94,20 +95,21 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; #elif 1 + // each thread hold 32 data constexpr index_t BlockSize = 256; constexpr index_t BPerBlock = 16; constexpr index_t KPerBlock = 64; constexpr index_t EPerBlock = 8; - constexpr index_t GemmMPerThreadSubC = 4; + constexpr index_t GemmMPerThreadSubC = 2; constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadA = 2; constexpr index_t GemmDataPerReadB = 4; using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; @@ -127,74 +129,6 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 4, 1>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 4, 4>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 4; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 1 - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 2, 2>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 2>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 2; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; #endif constexpr index_t GridSize = diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index 02abdae973..b118a55ae7 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -72,11 +72,11 @@ int main(int argc, char* argv[]) using namespace ck; #if 0 - constexpr index_t N = 256; + constexpr index_t N = 64; constexpr index_t C = 1536; constexpr index_t HI = 8; constexpr index_t WI = 8; - constexpr index_t K = 512; + constexpr index_t K = 256; constexpr index_t Y = 1; constexpr index_t X = 1;