mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Add SplitK support into Batched GEMM V3 (#1729)
* add bmm api
* add bf16 multi_d
* add ckProfiler for bf16
* add ckProfiler files
* add more instance; fixed 64bit index issue
* fixed naming
* enabled batched Ds
* use long_index for ds offsets
* clean
* add bmm fp8 ckProfiler
* Update example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update example/24_batched_gemm/run_batched_gemm_example_rowwise.inc
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update profiler/src/profile_gemm_universal_batched.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update profiler/include/profiler/profile_gemm_universal_batched_impl.hpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* clean
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* refactor batch offset func
* add splitk suppport into bmm_v3
* clean
* clean
* format
* fixed
* fix
---------
Co-authored-by: Jing Zhang <jizhan@fb.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 4d8fce33dd]
This commit is contained in:
@@ -52,6 +52,9 @@ using device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances =
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 2, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 64, 8, 8, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 64, 8, 8, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 2, 1, S<1, 32, 1, 8>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 32, 32, 1, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 64, 8, 8, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, BF16, BF16, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
|
||||
|
||||
@@ -42,6 +42,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
#ifdef __gfx94__
|
||||
// Compute friendly
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
|
||||
@@ -72,6 +73,7 @@ using device_batched_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std:
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Row, Col, DsLayout, Row, F8, F8, DsDataType, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
|
||||
|
||||
Reference in New Issue
Block a user