mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Add SplitK support into Batched GEMM V3 (#1729)
* add bmm api
* add bf16 multi_d
* add ckProfiler for bf16
* add ckProfiler files
* add more instance; fixed 64bit index issue
* fixed naming
* enabled batched Ds
* use long_index for ds offsets
* clean
* add bmm fp8 ckProfiler
* Update example/24_batched_gemm/batched_gemm_xdl_bf16_v3.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update example/24_batched_gemm/batched_gemm_xdl_fp8_rowwise_v3.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update example/24_batched_gemm/run_batched_gemm_example_rowwise.inc
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v1_default_instance.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_default_instance.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update profiler/src/profile_gemm_universal_batched.cpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* Update profiler/include/profiler/profile_gemm_universal_batched_impl.hpp
Co-authored-by: Bartłomiej Kocot <bartlomiejkocot98@gmail.com>
* clean
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update library/src/tensor_operation_instance/gpu/gemm_universal_batched/device_batched_gemm_xdl_universal_bf16_bf16_bf16/device_batched_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instance.cpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* Update include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp
* refactor batch offset func
* add splitk support into bmm_v3
* clean
* clean
* format
* fixed
* fix
---------
Co-authored-by: Jing Zhang <jizhan@fb.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
[ROCm/composable_kernel commit: 4d8fce33dd]
This commit is contained in:
@@ -31,7 +31,7 @@ enum struct GemmDataType
|
||||
|
||||
int profile_batched_gemm_universal(int argc, char* argv[])
|
||||
{
|
||||
if(argc != 18 && argc != 21)
|
||||
if(argc != 19 && argc != 22)
|
||||
{
|
||||
// clang-format off
|
||||
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
|
||||
@@ -44,11 +44,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
|
||||
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg6: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg7: time kernel (0=n0, 1=yes)\n");
|
||||
printf("arg8 to 17: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount\n");
|
||||
printf("arg8 to 18: M, N, K, StrideA, StrideB, StrideC, BatchStrideA, BatchStrideB, BatchStrideC, BatchCount, KBatch\n");
|
||||
printf("optional:\n");
|
||||
printf("arg18: number of warm-up cycles (default 1)\n");
|
||||
printf("arg19: number of iterations (default 10)\n");
|
||||
printf("arg20: memory for rotating buffer (default 0, size in MB)\n");
|
||||
printf("arg19: number of warm-up cycles (default 1)\n");
|
||||
printf("arg20: number of iterations (default 10)\n");
|
||||
printf("arg21: memory for rotating buffer (default 0, size in MB)\n");
|
||||
// clang-format on
|
||||
exit(1);
|
||||
}
|
||||
@@ -56,11 +56,11 @@ int profile_batched_gemm_universal(int argc, char* argv[])
|
||||
int n_warmup = 1;
|
||||
int n_iter = 10;
|
||||
uint64_t rotating = 0;
|
||||
if(argc == 21)
|
||||
if(argc == 22)
|
||||
{
|
||||
n_warmup = std::stoi(argv[18]);
|
||||
n_iter = std::stoi(argv[19]);
|
||||
rotating = std::stoull(argv[20]) * 1024 * 1024;
|
||||
n_warmup = std::stoi(argv[19]);
|
||||
n_iter = std::stoi(argv[20]);
|
||||
rotating = std::stoull(argv[21]) * 1024 * 1024;
|
||||
}
|
||||
|
||||
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
|
||||
@@ -83,6 +83,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
|
||||
const int BatchStrideC = std::stoi(argv[16]);
|
||||
|
||||
const int BatchCount = std::stoi(argv[17]);
|
||||
const int KBatch = std::stoi(argv[18]);
|
||||
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
|
||||
using F8 = ck::f8_t;
|
||||
@@ -159,6 +160,7 @@ int profile_batched_gemm_universal(int argc, char* argv[])
|
||||
StrideB_,
|
||||
StrideC_,
|
||||
BatchCount,
|
||||
KBatch,
|
||||
n_warmup,
|
||||
n_iter,
|
||||
rotating);
|
||||
|
||||
Reference in New Issue
Block a user