diff --git a/src/include/blockwise_gemm.hip.hpp b/src/include/blockwise_gemm.hip.hpp index d540dd8f4f..73d6732874 100644 --- a/src/include/blockwise_gemm.hip.hpp +++ b/src/include/blockwise_gemm.hip.hpp @@ -361,10 +361,10 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 // thread A-sub, B-sub for copy constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); + Number{}, Number{}, Number{}); constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); + Number{}, Number{}, Number{}); float p_thread[a_thread_mtx.GetElementSpace() + b_thread_mtx.GetElementSpace()]; @@ -377,66 +377,42 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr index_t MRepeat = MPerThread / MPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - // auto a_src_index = a_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetA; - // auto b_src_index = b_block_mtx.Get1dIndex(k_begin, 0) + mMyThreadOffsetB; Float4* reg_a = (Float4*)(p_a_thread); Float4* reg_b = (Float4*)(p_b_thread); Float4* reg_c = (Float4*)(p_c_thread); void* a_loc = (void*)(p_a_block + mMyThreadOffsetA); void* b_loc = (void*)(p_b_block + mMyThreadOffsetB); - // loop over k - int k_chunk = K; - //for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop * k_chunk) - index_t k_begin = 0; + + int lds_a_block_off = sizeof(Float) * M; + int lds_b_block_off = sizeof(Float) * N; + int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float); + int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float); + ds_read_b128(reg_a[0], a_loc, 0); + ds_read_b128(reg_b[0], b_loc, 0); + ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1); + ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1); + lgkmcnt(2); + outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); + lgkmcnt(1); + outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); + lgkmcnt(0); +#pragma unroll + for(int k_i = 1; k_i < K; k_i++) { - -#if 0 - ds_read_b128(reg_a[0], a_loc, 0); - ds_read_b128(reg_a[1], a_loc, 256); - ds_read_b128(reg_b[0], b_loc, 0); - ds_read_b128(reg_b[1], b_loc, 128); - - lgkmcnt(0); - - outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); - outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); + ds_read_b128(reg_a[0], a_loc, k_i * lds_a_block_off); outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); + ds_read_b128(reg_b[0], b_loc, k_i * lds_b_block_off); outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); -#else - int k = k_begin; - int lds_a_block_off = sizeof(Float) * M; - int lds_b_block_off = sizeof(Float) * N; - int lds_a_block_off_1 = MPerLevel1Cluster * sizeof(Float); - int lds_b_block_off_1 = NPerLevel1Cluster * sizeof(Float); - ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off); - ds_read_b128(reg_b[0], b_loc, k * lds_b_block_off); - ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k * lds_b_block_off); - ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k * lds_a_block_off); + ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k_i * lds_b_block_off); + ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k_i * lds_a_block_off); lgkmcnt(2); outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); lgkmcnt(1); outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); lgkmcnt(0); -#pragma unroll - for(int i = 0; i < k_chunk - 1; i++) - { - k = k + 1; - ds_read_b128(reg_a[0], a_loc, k * lds_a_block_off); - outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); - ds_read_b128(reg_b[0], b_loc, k * lds_b_block_off); - outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); - ds_read_b128(reg_b[1], b_loc, lds_b_block_off_1 + k * lds_b_block_off); - ds_read_b128(reg_a[1], a_loc, lds_a_block_off_1 + k * lds_a_block_off); - lgkmcnt(2); - outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); - lgkmcnt(1); - outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); - lgkmcnt(0); - } - outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); - outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); -#endif } + outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); + outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); } template