diff --git a/driver/driver.hip.cpp b/driver/driver.hip.cpp
index cb45c89f22..ab08356262 100644
--- a/driver/driver.hip.cpp
+++ b/driver/driver.hip.cpp
@@ -580,13 +580,16 @@ int main(int argc, char* argv[])
 #if 0
     in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+#elif 0
+    in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+    wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
 #elif 0
     in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
 #elif 1
     in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
     wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-#elif 1
+#elif 0
     in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
 
     auto gen_wei = [](auto... is) {
diff --git a/src/include/blockwise_batched_gemm.hip.hpp b/src/include/blockwise_batched_gemm.hip.hpp
index 55c41c4417..7e6e037ba1 100644
--- a/src/include/blockwise_batched_gemm.hip.hpp
+++ b/src/include/blockwise_batched_gemm.hip.hpp
@@ -289,9 +289,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
                             const FloatB* __restrict__ p_b_block,
                             FloatC* __restrict__ p_c_thread) const
     {
-        constexpr auto True = integral_constant<bool, true>{};
-        constexpr auto False = integral_constant<bool, false>{};
-
         constexpr auto a_block_mtx = BlockMatrixA{};
         constexpr auto b_block_mtx = BlockMatrixB{};
         constexpr auto c_thread_mtx = ThreadMatrixC{};
@@ -371,6 +368,102 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
         outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
         outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
     }
+
+    template <class FloatA, class FloatB, class FloatC>
+    __device__ void Run_asm_v2(const FloatA* __restrict__ p_a_block,
+                               const FloatB* __restrict__ p_b_block,
+                               FloatC* __restrict__ p_c_thread) const
+    {
+        constexpr auto a_block_mtx = BlockMatrixA{};
+        constexpr auto b_block_mtx = BlockMatrixB{};
+        constexpr auto c_thread_mtx = ThreadMatrixC{};
+
+        constexpr index_t M = a_block_mtx.NCol();
+        constexpr index_t N = b_block_mtx.NCol();
+        constexpr index_t K = a_block_mtx.NRow(); // A is transposed
+
+        constexpr index_t MPerThread = c_thread_mtx.NRow();
+        constexpr index_t NPerThread = c_thread_mtx.NCol();
+
+        // thread A, B for GEMM
+        // A is transposed, b is not
+        constexpr auto a_thread_mtx =
+            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<MPerThread>{});
+
+        constexpr auto b_thread_mtx =
+            make_ConstantMatrixDescriptor(Number<KPerThreadLoop>{}, Number<NPerThread>{});
+
+        // thread A-sub, B-sub for copy
+        constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
+
+        constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
+            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
+
+        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
+        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
+
+        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
+        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
+
+        // assertion for inline asm
+        static_assert(is_same<FloatA, float>::value && is_same<FloatB, float>::value &&
+                          is_same<FloatC, float>::value,
+                      "Run_asm only deal with float\n");
+
+        static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
+                          MPerThread == 8 && NPerThread == 8,
+                      "Run_asm cannot deal with this GEMM shape yet\n");
+
+        static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_asm only do float4 read\n");
+
+        static_assert(
+            BlockMatrixStrideA == 0 && BatchPerThread == 1,
+            "Run_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == 1 for now\n");
+
+        using Float4 = vector_type<float, 4>::MemoryType;
+
+        Float4* reg_a = (Float4*)(p_a_thread);
+        Float4* reg_b = (Float4*)(p_b_thread);
+        Float4* reg_c = (Float4*)(p_c_thread);
+
+        void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA);
+        void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB);
+
+        constexpr index_t a_lds_row_stride = sizeof(Float) * M;
+        constexpr index_t b_lds_row_stride = sizeof(Float) * N;
+        constexpr index_t a_lds_cluster_col_stride = sizeof(Float) * MPerLevel1Cluster;
+        constexpr index_t b_lds_cluster_col_stride = sizeof(Float) * NPerLevel1Cluster;
+
+        ds_read_b128(reg_a[0], a_lds_loc, 0);
+        ds_read_b128(reg_b[0], b_lds_loc, 0);
+        ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride);
+        ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride);
+        lgkmcnt(2);
+        outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+        lgkmcnt(1);
+        outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+
+#pragma unroll
+        for(index_t k = 1; k < K; ++k)
+        {
+            ds_read_b128(reg_a[0], a_lds_loc, k * a_lds_row_stride);
+            lgkmcnt(1);
+            outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+            ds_read_b128(reg_b[0], b_lds_loc, k * b_lds_row_stride);
+            outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+            ds_read_b128(reg_b[1], b_lds_loc, b_lds_cluster_col_stride + k * b_lds_row_stride);
+            ds_read_b128(reg_a[1], a_lds_loc, a_lds_cluster_col_stride + k * a_lds_row_stride);
+            lgkmcnt(2);
+            outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
+            lgkmcnt(1);
+            outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
+        }
+
+        lgkmcnt(0);
+        outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
+        outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
+    }
 #endif
 
     template
diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
index 669ba69deb..c1b07b19f2 100644
--- a/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
+++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
@@ -273,7 +273,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
 
             __syncthreads();
 
+#if 1
             blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);
+#elif 0
+            blockwise_batch_gemm.Run_asm(p_wei_block, p_in_block, p_out_thread);
+#elif 0
+            blockwise_batch_gemm.Run_asm_v2(p_wei_block, p_in_block, p_out_thread);
+#endif
 
             __syncthreads();
         }
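Note (not part of the diff): `Run_asm_v2` relies on `ds_read_b128`, `lgkmcnt`, and `outerProduct4x4` helpers that are defined elsewhere in the repo and are not shown here. Below is a minimal sketch of what the wait-count side of that contract looks like; the name `lgkmcnt_sketch`, the template-parameter form, and the exact asm constraints are illustrative assumptions, since the repo's own wrapper takes a runtime argument as the call sites above show.

```cpp
#include <hip/hip_runtime.h>

// Illustrative sketch, not the repo's implementation: wait until at most Cnt
// LDS/GDS ("lgkm") operations are still outstanding. Because LDS reads complete
// in issue order, waiting on a partial count lets earlier ds_read_b128 results
// be consumed while later reads are still in flight.
template <unsigned Cnt>
__device__ inline void lgkmcnt_sketch()
{
    asm volatile("s_waitcnt lgkmcnt(%0)" : : "I"(Cnt) : "memory");
}
```

In the prologue of `Run_asm_v2`, for example, four `ds_read_b128` loads are issued (`reg_a[0]`, `reg_b[0]`, `reg_b[1]`, `reg_a[1]`); `lgkmcnt(2)` then guarantees the first two have landed before the first `outerProduct4x4` consumes `reg_a[0]` and `reg_b[0]`, and `lgkmcnt(1)` releases `reg_b[1]` for the second, so the remaining LDS reads overlap with the FMA work.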