From 4111d2fbfdf12c57716ba9e440754f0d48dde0f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 13 Dec 2024 21:08:35 +0100 Subject: [PATCH] Add SplitK support into Batched GEMM V3 (#1729) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add bmm api * add bf16 multi_d * add ckProfiler for bf16 * add ckProfiler files * add more instance; fixed 64bit index issue * fixed naming * enabled batched Ds * use long_index for ds offsets * clean * add bmm fp8 ckProfiler * Update example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp Co-authored-by: Bartłomiej Kocot * Update example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp Co-authored-by: Bartłomiej Kocot * Update example/24_batched_gemm/run_batched_gemm_example_rowwise.inc Co-authored-by: Bartłomiej Kocot * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp Co-authored-by: Bartłomiej Kocot * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp Co-authored-by: Bartłomiej Kocot * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp Co-authored-by: Bartłomiej Kocot * Update profiler/src/profile_gemm_universal_batched.cpp Co-authored-by: Bartłomiej Kocot * Update profiler/include/profiler/profile_gemm_universal_batched_impl.hpp Co-authored-by: Bartłomiej Kocot * clean * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * refactor batch offset func * add splitk suppport into bmm_v3 * clean * clean * format * fixed * fix --------- Co-authored-by: Jing Zhang Co-authored-by: zjing14 [ROCm/composable_kernel commit: 4d8fce33dddfc003432ae06848f6416a9d5d5e2f] --- .../batched_gemm_xdl_bf16_v3.cpp | 4 +- .../device/device_batched_gemm_multi_d.hpp | 3 +- ...atched_gemm_multiple_d_xdl_cshuffle_v3.hpp | 45 ++++-- .../gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp | 16 +- ..._xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp | 3 + ...gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp | 2 + .../profile_gemm_universal_batched_impl.hpp | 148 ++++++++++-------- .../src/profile_gemm_universal_batched.cpp | 20 +-- 8 files changed, 137 insertions(+), 104 deletions(-) diff --git a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp index fa8b752185..548500518f 100644 --- a/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp +++ b/example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp @@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD 2, // ABlockTransferSrcVectorDim 8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferDstScalarPerVector_AK1 - 1, // ABlockLdsExtraM + 0, // ABlockLdsExtraM S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferDstScalarPerVector_BK1 - 1, // BBlockLdsExtraN + 0, // BBlockLdsExtraN 1, // CShuffleMXdlPerWavePerShuffle 1, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp index 58c0288e8f..8fb4a71f55 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp @@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) = 0; + CDEElementwiseOperation cde_element_op, + index_t KBatch) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp index 314ecdf76e..5f5bea4f86 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -41,12 +41,15 @@ __global__ void __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + // populate pointer, desc for Ds static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { // D pointer @@ -54,8 +57,8 @@ __global__ void }); GridwiseGemm::template Run( - karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid + c_batch_offset, p_shared, @@ -87,12 +90,15 @@ __global__ void __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + // populate pointer, desc for Ds static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { // D pointer @@ -100,8 +106,8 @@ __global__ void }); GridwiseGemm::template Run_2Lds( - karg.p_a_grid + a_batch_offset, - karg.p_b_grid + b_batch_offset, + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid + c_batch_offset, p_shared_0, @@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 index_t Batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, - CElementwiseOperation c_element_op_) + CElementwiseOperation c_element_op_, + index_t KBatch_) : GridwiseGemm::Argument{p_a_grid_, p_b_grid_, p_ds_grid_, @@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 StrideB_, StrideDs_, StrideE_, - 1, + KBatch_, a_element_op_, b_element_op_, c_element_op_}, @@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 arg.Print(); } - if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1) + if(!GridwiseGemm::CheckValidity(arg)) { throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); } index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch); + std::tie(gdx, gdy, gdz) = + GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch); float ave_time = 0; @@ -387,10 +395,11 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 rotating_mem.Next(); // clear c mem if(arg_.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, - 0, - arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); + hipGetErrorString( + hipMemsetAsync(arg_.p_c_grid, + 0, + arg.Batch * arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); }; ave_time = ck::utility::launch_and_time_kernel_with_preprocess( @@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + index_t KBatch = 1) { return Argument{static_cast(p_a), static_cast(p_b), @@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 Batch, a_element_op, b_element_op, - c_element_op}; + c_element_op, + KBatch}; } static auto MakeInvoker() { return Invoker{}; } @@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 index_t BatchStrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) override + CElementwiseOperation c_element_op, + index_t KBatch = 1) override { return std::make_unique(static_cast(p_a), static_cast(p_b), @@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 Batch, a_element_op, b_element_op, - c_element_op); + c_element_op, + KBatch); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index c7038ed4fa..e5a31f8d1f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -41,7 +41,7 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, @@ -76,7 +76,7 @@ __global__ void __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); GridwiseGemm::template Run_2Lds( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, @@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 struct SplitKBatchOffset { - __device__ SplitKBatchOffset(Argument& karg) + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) { if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead; + a_k_split_offset = k_id * karg.KRead; } else if constexpr(is_same_v) { - a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + a_k_split_offset = k_id * karg.KRead * karg.StrideA; } if constexpr(is_same_v) { - b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + b_k_split_offset = k_id * karg.KRead * karg.StrideB; } else if constexpr(is_same_v) { - b_k_split_offset = blockIdx.z * karg.KRead; + b_k_split_offset = k_id * karg.KRead; } - if(blockIdx.z < static_cast(karg.KBatch - 1)) + if(k_id < karg.KBatch - 1) { karg.K = karg.KRead; } diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp index 5db041de09..21cef335c5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 64, 8, 8, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 64, 8, 8, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_f8_f8_bf16/device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_f8_f8_bf16/device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp index 355dc3212b..552ac3cd00 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_f8_f8_bf16/device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_f8_f8_bf16/device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp @@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + #ifdef __gfx94__ // Compute friendly DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, @@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std: //##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, diff --git a/profiler/include/profiler/profile_gemm_universal_batched_impl.hpp b/profiler/include/profiler/profile_gemm_universal_batched_impl.hpp index 53f81162ac..f4300af8d8 100644 --- a/profiler/include/profiler/profile_gemm_universal_batched_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_batched_impl.hpp @@ -48,6 +48,7 @@ bool profile_gemm_universal_batched_impl(int do_verification, int StrideB, int StrideC, int BatchCount, + int KBatch, int n_warmup, int n_iter, uint64_t rotating = 0) @@ -147,89 +148,100 @@ bool profile_gemm_universal_batched_impl(int do_verification, float best_ave_time = 0; float best_tflops = 0; float best_gb_per_sec = 0; + float best_kbatch = 0; // profile device op instances for(auto& op_ptr : op_ptrs) { - std::unique_ptr argument_ptr; - // false branch for multi d dl kernel + std::vector kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; - argument_ptr = - op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - {}, - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - BatchCount, - StrideA, - StrideB, - {}, - StrideC, - BatchStrideA, - BatchStrideB, - {}, - BatchStrideC, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) + if(KBatch > 0) { - // re-init C to zero before profiling next kernel - c_device_buf.SetZero(); + kbatch_list = {KBatch}; + } - std::string op_name = op_ptr->GetTypeString(); + for(std::size_t i = 0; i < kbatch_list.size(); i++) + { + auto kbatch_curr = kbatch_list[i]; - float ave_time = invoker_ptr->Run( - argument_ptr.get(), - StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter, true, rotating_count}); + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + {}, + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + BatchCount, + StrideA, + StrideB, + {}, + StrideC, + BatchStrideA, + BatchStrideB, + {}, + BatchStrideC, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + kbatch_curr); - std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + auto invoker_ptr = op_ptr->MakeInvokerPointer(); - std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(CDataType) * M * N) * - BatchCount; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) + if(op_ptr->IsSupportedArgument(argument_ptr.get())) { - best_op_name = op_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } + std::string op_name = op_ptr->GetTypeString(); - if(do_verification) - { - c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + float ave_time = invoker_ptr->Run( + argument_ptr.get(), + StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter, true, rotating_count}); - pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result); + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; - if(do_log) + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << ", KBatch " << kbatch_curr << std::endl; + + if(tflops > best_tflops) { - LogRangeAsType(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; - LogRangeAsType(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") - << std::endl; - LogRangeAsType( - std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") - << std::endl; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + + pass = pass & ck::utils::check_err(c_g_m_n_device_result, c_g_m_n_host_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") + << std::endl; + } } } - } - else - { - std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } } } @@ -270,8 +282,8 @@ bool profile_gemm_universal_batched_impl(int do_verification, std::cout << " B = " << BatchCount << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA << " StrideB = " << StrideB << " StrideC = " << StrideC - << ": " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec - << " GB/s, " << best_op_name << std::endl; + << " KBatch = " << best_kbatch << ": " << best_ave_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; return pass; } diff --git a/profiler/src/profile_gemm_universal_batched.cpp b/profiler/src/profile_gemm_universal_batched.cpp index 4afef8e55f..d57511fbfc 100644 --- a/profiler/src/profile_gemm_universal_batched.cpp +++ b/profiler/src/profile_gemm_universal_batched.cpp @@ -31,7 +31,7 @@ enum struct GemmDataType int profile_batched_gemm_universal(int argc, char* argv[]) { - if(argc != 18 && argc != 21) + if(argc != 19 && argc != 22) { // clang-format off printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); @@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[]) printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); printf("arg6: print tensor value (0: no; 1: yes)\n"); printf("arg7: time kernel (0=n0, 1=yes)\n"); - printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n"); + printf("arg8 to 18: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount, KBatch\n"); printf("optional:\n"); - printf("arg18: number of warm-up cycles (default 1)\n"); - printf("arg19: number of iterations (default 10)\n"); - printf("arg20: memory for rotating buffer (default 0, size in MB)\n"); + printf("arg19: number of warm-up cycles (default 1)\n"); + printf("arg20: number of iterations (default 10)\n"); + printf("arg21: memory for rotating buffer (default 0, size in MB)\n"); // clang-format on exit(1); } @@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[]) int n_warmup = 1; int n_iter = 10; uint64_t rotating = 0; - if(argc == 21) + if(argc == 22) { - n_warmup = std::stoi(argv[18]); - n_iter = std::stoi(argv[19]); - rotating = std::stoull(argv[20]) * 1024 * 1024; + n_warmup = std::stoi(argv[19]); + n_iter = std::stoi(argv[20]); + rotating = std::stoull(argv[21]) * 1024 * 1024; } const auto data_type = static_cast(std::stoi(argv[2])); @@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[]) const int BatchStrideC = std::stoi(argv[16]); const int BatchCount = std::stoi(argv[17]); + const int KBatch = std::stoi(argv[18]); #if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) using F8 = ck::f8_t; @@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[]) StrideB_, StrideC_, BatchCount, + KBatch, n_warmup, n_iter, rotating);