From 49bac8cef77e7e145344bdfe66d362aacc31ecf4 Mon Sep 17 00:00:00 2001 From: mtgu0705 Date: Mon, 17 Feb 2025 16:04:28 +0800 Subject: [PATCH] Added b preshuffle pipeline v3 support. --- .../gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp | 23 +++++- ...e_gemm_pipeline_xdlops_b_preshuffle_v3.hpp | 82 +++++++++++++++++-- 2 files changed, 96 insertions(+), 9 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp index 45151f9e2b..544438bccb 100644 --- a/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp +++ b/example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp @@ -28,9 +28,9 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa static constexpr bool PermuteA = false; static constexpr bool PermuteB = false; -static constexpr ck::index_t KPerBlock = 128; // clang-format off +#if 0 using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle< ALayout, BLayout, CLayout, @@ -38,7 +38,7 @@ using DeviceGemmV2Instance = AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, - KPerBlock, 16, 32, + 256, 16, 32, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, @@ -47,7 +47,26 @@ using DeviceGemmV2Instance = 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 4, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>; + +#else +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle< + ALayout, BLayout, CLayout, + ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, + 256, 256, + 128, 16, 32, + 32, 32, + 4, 4, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 32, 32, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F8, F8, PermuteA, PermuteB>; +#endif // clang-format on template diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp index 0f2b688a8e..6939429599 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -510,10 +510,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto b_thread_dequant_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_dequant_bufs; constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); // Global prefetch A1 B1 @@ -545,6 +548,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); // Initialize C c_thread_buf.Clear(); @@ -594,9 +604,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}]; b_thread_vec.template AsType()(ik) = - b_thread_bufs[mfma_reg_buf] - [Number{}]; + b_thread_dequant_bufs[mfma_reg_buf] + [Number{}]; }); using mfma_input_type = @@ -633,6 +643,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(local_read_buf)); } else { @@ -652,6 +669,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(mfma_reg_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(mfma_reg_buf)); } HotLoopScheduler(m0); @@ -691,7 +715,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}]; b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; }); @@ -720,6 +744,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}, I0, I0, k0, I0, I0), a_thread_buf); }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I1)); } else { @@ -732,6 +763,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}, I0, I0, k0, I0, I0), a_thread_buf); }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); } EpilogueScheduler_1(m0); @@ -748,7 +786,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}]; b_thread_vec.template AsType()(ik) = - b_thread_bufs[I1][Number{}]; }); @@ -776,6 +814,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}, I0, I0, k0, I0, I0), a_thread_buf); }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I1)); EpilogueScheduler_2(); } @@ -797,7 +842,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}]; b_thread_vec.template AsType()(ik) = - b_thread_bufs[I0][Number{}]; }); @@ -823,6 +868,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}, I0, I0, k0, I0, I0), a_thread_buf); }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); EpilogueScheduler_2(); } @@ -855,6 +907,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + KPack>; + + const PassThrough b_element_op{}; + BThreadDequantCopy b_thread_dequant_copy_{b_element_op}; }; } // namespace ck