diff --git a/src/include/amd_inline_asm.hip.hpp b/src/include/amd_inline_asm.hip.hpp index 38ac970981..4a8be241ba 100644 --- a/src/include/amd_inline_asm.hip.hpp +++ b/src/include/amd_inline_asm.hip.hpp @@ -201,7 +201,7 @@ __device__ void ds_read_b128(vector_type::MemoryType& r, void* lds, in if(offset == 0) { asm volatile("\n \ - ds_read_b128 %0, %1 offset:0 \n \ + ds_read_b128 %0, %1 \n \ " : "=v"(r) : "v"(__to_local(lds))); @@ -350,6 +350,14 @@ __device__ void ds_read_b128(vector_type::MemoryType& r, void* lds, in : "=v"(r) : "v"(__to_local(lds))); } + else if(offset == 2432) + { + asm volatile("\n \ + ds_read_b128 %0, %1 offset:2432 \n \ + " + : "=v"(r) + : "v"(__to_local(lds))); + } else if(offset == 2560) { asm volatile("\n \ @@ -358,6 +366,14 @@ __device__ void ds_read_b128(vector_type::MemoryType& r, void* lds, in : "=v"(r) : "v"(__to_local(lds))); } + else if(offset == 2688) + { + asm volatile("\n \ + ds_read_b128 %0, %1 offset:2688 \n \ + " + : "=v"(r) + : "v"(__to_local(lds))); + } else if(offset == 2816) { asm volatile("\n \ @@ -366,6 +382,14 @@ __device__ void ds_read_b128(vector_type::MemoryType& r, void* lds, in : "=v"(r) : "v"(__to_local(lds))); } + else if(offset == 2944) + { + asm volatile("\n \ + ds_read_b128 %0, %1 offset:2944 \n \ + " + : "=v"(r) + : "v"(__to_local(lds))); + } else if(offset == 3072) { asm volatile("\n \ @@ -374,6 +398,14 @@ __device__ void ds_read_b128(vector_type::MemoryType& r, void* lds, in : "=v"(r) : "v"(__to_local(lds))); } + else if(offset == 3200) + { + asm volatile("\n \ + ds_read_b128 %0, %1 offset:3200 \n \ + " + : "=v"(r) + : "v"(__to_local(lds))); + } else if(offset == 3328) { asm volatile("\n \ @@ -382,6 +414,14 @@ __device__ void ds_read_b128(vector_type::MemoryType& r, void* lds, in : "=v"(r) : "v"(__to_local(lds))); } + else if(offset == 3456) + { + asm volatile("\n \ + ds_read_b128 %0, %1 offset:3456 \n \ + " + : "=v"(r) + : "v"(__to_local(lds))); + } else if(offset == 3584) { asm volatile("\n \ @@ -390,6 
+430,14 @@ __device__ void ds_read_b128(vector_type::MemoryType& r, void* lds, in : "=v"(r) : "v"(__to_local(lds))); } + else if(offset == 3712) + { + asm volatile("\n \ + ds_read_b128 %0, %1 offset:3712 \n \ + " + : "=v"(r) + : "v"(__to_local(lds))); + } else if(offset == 3840) { asm volatile("\n \ @@ -398,6 +446,14 @@ __device__ void ds_read_b128(vector_type::MemoryType& r, void* lds, in : "=v"(r) : "v"(__to_local(lds))); } + else if(offset == 3968) + { + asm volatile("\n \ + ds_read_b128 %0, %1 offset:3968 \n \ + " + : "=v"(r) + : "v"(__to_local(lds))); + } else if(offset == 4096) { asm volatile("\n \ diff --git a/src/include/blockwise_batched_gemm.hip.hpp b/src/include/blockwise_batched_gemm.hip.hpp index 7e6e037ba1..b50acfe0d4 100644 --- a/src/include/blockwise_batched_gemm.hip.hpp +++ b/src/include/blockwise_batched_gemm.hip.hpp @@ -293,8 +293,6 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 constexpr auto b_block_mtx = BlockMatrixB{}; constexpr auto c_thread_mtx = ThreadMatrixC{}; - constexpr index_t M = a_block_mtx.NCol(); - constexpr index_t N = b_block_mtx.NCol(); constexpr index_t K = a_block_mtx.NRow(); // A is transposed constexpr index_t MPerThread = c_thread_mtx.NRow(); @@ -344,24 +342,26 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 reg_a[0] = *reinterpret_cast(&p_a_block[mMyThreadOffsetA]); reg_b[0] = *reinterpret_cast(&p_b_block[mMyThreadOffsetB]); - reg_b[1] = - *reinterpret_cast(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]); - reg_a[1] = - *reinterpret_cast(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]); + reg_b[1] = *reinterpret_cast( + &p_b_block[b_block_mtx.Get1dIndex(0, NPerLevel1Cluster) + mMyThreadOffsetB]); + reg_a[1] = *reinterpret_cast( + &p_a_block[a_block_mtx.Get1dIndex(0, MPerLevel1Cluster) + mMyThreadOffsetA]); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); #pragma unroll 
for(index_t k = 1; k < K; ++k) { - reg_a[0] = *reinterpret_cast(&p_a_block[mMyThreadOffsetA + k * M]); + reg_a[0] = *reinterpret_cast( + &p_a_block[a_block_mtx.Get1dIndex(k, 0) + mMyThreadOffsetA]); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); - reg_b[0] = *reinterpret_cast(&p_b_block[mMyThreadOffsetB + k * N]); + reg_b[0] = *reinterpret_cast( + &p_b_block[b_block_mtx.Get1dIndex(k, 0) + mMyThreadOffsetB]); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); reg_b[1] = *reinterpret_cast( - &p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]); + &p_b_block[b_block_mtx.Get1dIndex(k, NPerLevel1Cluster) + mMyThreadOffsetB]); reg_a[1] = *reinterpret_cast( - &p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]); + &p_a_block[a_block_mtx.Get1dIndex(k, MPerLevel1Cluster) + mMyThreadOffsetA]); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); } @@ -430,10 +430,10 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 void* a_lds_loc = (void*)(p_a_block + mMyThreadOffsetA); void* b_lds_loc = (void*)(p_b_block + mMyThreadOffsetB); - constexpr index_t a_lds_row_stride = sizeof(Float) * M; - constexpr index_t b_lds_row_stride = sizeof(Float) * N; - constexpr index_t a_lds_cluster_col_stride = sizeof(Float) * MPerLevel1Cluster; - constexpr index_t b_lds_cluster_col_stride = sizeof(Float) * NPerLevel1Cluster; + constexpr index_t a_lds_row_stride = sizeof(Float) * a_block_mtx.RowStride(); + constexpr index_t b_lds_row_stride = sizeof(Float) * b_block_mtx.RowStride(); + constexpr index_t a_lds_cluster_col_stride = sizeof(Float) * MPerLevel1Cluster; + constexpr index_t b_lds_cluster_col_stride = sizeof(Float) * NPerLevel1Cluster; ds_read_b128(reg_a[0], a_lds_loc, 0); ds_read_b128(reg_b[0], b_lds_loc, 0); diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp
b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp index c1b07b19f2..4a3459119d 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp @@ -213,7 +213,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn // set threadwise output tensor to 0 threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); -#if 0 +#if 1 const Float* p_in_global_block_offset = p_in_global + in_c_h_w_n_global_desc.Get1dIndex( @@ -241,7 +241,13 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn __syncthreads(); +#if 1 blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); +#elif 0 + blockwise_batch_gemm.Run_asm(p_wei_block, p_in_block, p_out_thread); +#elif 1 + blockwise_batch_gemm.Run_asm_v2(p_wei_block, p_in_block, p_out_thread); +#endif __syncthreads(); } @@ -277,7 +283,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); #elif 0 blockwise_batch_gemm.Run_asm(p_wei_block, p_in_block, p_out_thread); -#elif 0 +#elif 1 blockwise_batch_gemm.Run_asm_v2(p_wei_block, p_in_block, p_out_thread); #endif diff --git a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp index 640683fe1e..5595d596e9 100644 --- a/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp +++ b/src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp @@ -293,8 +293,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_register_clipboard); - // LDS double buffer: GEMM on current data - blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); +// LDS 
double buffer: GEMM on current data +#if 1 + blockwise_batch_gemm.Run +#elif 0 + blockwise_batch_gemm.Run_asm +#else + blockwise_batch_gemm.Run_asm_v2 +#endif + (p_wei_block_now, p_in_block_now, p_out_thread); // LDS double buffer: store next data to LDS blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, @@ -321,8 +328,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset, p_wei_register_clipboard); - // LDS double buffer: GEMM on current data - blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); +// LDS double buffer: GEMM on current data +#if 1 + blockwise_batch_gemm.Run +#elif 0 + blockwise_batch_gemm.Run_asm +#else + blockwise_batch_gemm.Run_asm_v2 +#endif + (p_wei_block_double, p_in_block_double, p_out_thread); // LDS double buffer: store next data to LDS blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, @@ -333,10 +347,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn // odd iteration __syncthreads(); - // LDS double buffer: GEMM on current data - blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space, - p_in_block_double + in_block_space, - p_out_thread); +// LDS double buffer: GEMM on current data +#if 1 + blockwise_batch_gemm.Run +#elif 0 + blockwise_batch_gemm.Run_asm +#else + blockwise_batch_gemm.Run_asm_v2 +#endif + (p_wei_block_double + wei_block_space, + p_in_block_double + in_block_space, + p_out_thread); } } }