mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add SplitK support into Batched GEMM V3 (#1729)
* add bmm api * add bf16 multi_d * add ckProfiler for bf16 * add ckProfiler files * add more instance; fixed 64bit index issue * fixed naming * enabled batched Ds * use long_index for ds offsets * clean * add bmm fp8 ckProfiler * Update example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update example/24_batched_gemm/run_batched_gemm_example_rowwise.inc Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update profiler/src/profile_gemm_universal_batched.cpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * Update profiler/include/profiler/profile_gemm_universal_batched_impl.hpp Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com> * clean * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp * refactor batch offset func * add splitk suppport into bmm_v3 * clean * clean * format * fixed * fix --------- Co-authored-by: Jing Zhang <jizhan@fb.com> Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
0, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
0, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
|
||||
Reference in New Issue
Block a user