mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Add SplitK support into Batched GEMM V3 (#1729)
* add bmm api
* add bf16 multi_d
* add ckProfiler for bf16
* add ckProfiler files
* add more instance; fixed 64bit index issue
* fixed naming
* enabled batched Ds
* use long_index for ds offsets
* clean
* add bmm fp8 ckProfiler
* Update example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update example/24_batched_gemm/run_batched_gemm_example_rowwise.inc
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update profiler/src/profile_gemm_universal_batched.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update profiler/include/profiler/profile_gemm_universal_batched_impl.hpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* clean
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* refactor batch offset func
* add splitk suppport into bmm_v3
* clean
* clean
* format
* fixed
* fix
---------
Co-authored-by: Jing Zhang <jizhan@fb.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 4d8fce33dd]
This commit is contained in:
@@ -78,14 +78,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
1, // ABlockLdsExtraM
|
||||
0, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
1, // BBlockLdsExtraN
|
||||
0, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
|
||||
Reference in New Issue
Block a user