diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp index eeca3acfe6..cccfac88e8 100644 --- a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp @@ -66,32 +66,36 @@ struct MultiplyMultiply void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) { const int NRepeat = 1; const int KRepeat = 4; + const int NWave = 4; const int KLane = 2; - const int NLane = 128; + const int NLane = 32; const int KPack = 16; - int N0 = N / (NRepeat * NLane); + int N0 = N / (NRepeat * NLane * NWave); int K0 = K / (KRepeat * KLane * KPack); int tempn, tempk; for (int n = 0; n < N; ++n) { for (int k = 0; k < K; ++k) { - int n0 = n / (NRepeat * NLane); + int n0 = n / (NRepeat * NLane * NWave); int k0 = k / (KRepeat * KLane * KPack); - tempn = n % (NRepeat * NLane); + tempn = n % (NRepeat * NLane * NWave); tempk = k % (KRepeat * KLane * KPack); - int n1 = tempn / NLane; + int n1 = tempn / (NLane * NWave); int k1 = tempk / (KLane * KPack); - int n2 = n1 % NLane; + tempn = tempn % (NLane * NWave); tempk = tempk % (KLane * KPack); + int n2 = tempn / NLane; int k2 = tempk / KPack; + int n3 = tempn % NLane; int k3 = tempk % KPack; - int outputIndex = n0 * KPack * NLane * KLane * KRepeat * NRepeat * K0 - + k0 * KPack * NLane * KLane * KRepeat * NRepeat - + n1 * KPack * NLane * KLane * KRepeat - + k1 * KPack * NLane * KLane + int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0 + + k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat + + n1 * KPack * NLane * KLane * NWave * KRepeat + + k1 * KPack * NLane * KLane * NWave + + n2 * KPack * NLane * KLane + k2 * KPack * NLane - + n2 * KPack + + n3 * KPack + k3; dst[outputIndex] = src[n * K + k]; @@ -269,7 +273,7 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 1}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_btype = diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp index 9e85b7c62f..6eb0812347 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -357,7 +357,7 @@ struct BlockwiseGemmXdlops_pipeline_v3{}]; // if(threadIdx.x==0) { - // printf("%f, %f; ", type_convert(a_thread_vec.template AsType()(ik)), ype_convert(b_thread_vec.template AsType()(ik))); + // printf("%f, %f; ", type_convert(a_thread_vec.template AsType()(ik)), type_convert(b_thread_vec.template AsType()(ik))); // } }); @@ -451,6 +451,11 @@ struct BlockwiseGemmXdlops_pipeline_v3{}); + + static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -462,6 +467,48 @@ struct BlockwiseGemmXdlops_pipeline_v3{}]; }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf1, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec = + b_blockwise_copy.template GetSrcThreadScratchIdx, Number<1>{}>(); + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); using mfma_input_type = typename vector_type::type;