mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-27 00:14:35 +00:00
[GEMM] gemm_universal related optimization (#1453)
* replace buffer_atomic with global_atomic
* fixed global_atomic_add
* added bf16 atomic_add
* format
* clang-format-12
* clean
* clean
* add guards
* Update gtest.cmake
* enabled splitk_gemm_multi_d
* format
* add ckProfiler
* format
* fixed naming
* format
* clean
* clean
* add guards
* fix clang format
* format
* add kbatch printout
* clean
* Add rocm6.2 related gemm optimization
* Limit bf16 atomic usage
* remove redundant RCR gemm_universal instance
* Add RRR fp8 gemm universal instance
* Bug fix
* Add GPU_TARGET guard to FP8/BF8 target
* bug fix
* update cmake
* remove all fp8/bf8 example if arch not support
* Enable fp8 RRR support in ckProfiler
* limit greedy-reverse flag to gemm_universal in ckProfiler
---------
Co-authored-by: Jing Zhang <jizhan@fb.com>
Co-authored-by: Jing Zhang <jizhan@meta.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
[ROCm/composable_kernel commit: 3049b5467c]
This commit is contained in:
@@ -34,7 +34,7 @@ enum struct GemmDataType
|
||||
|
||||
int profile_gemm_multiply_multiply(int argc, char* argv[])
|
||||
{
|
||||
if(argc != 16 && argc != 19)
|
||||
if(argc != 16 && argc != 20)
|
||||
{
|
||||
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
|
||||
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: "
|
||||
@@ -50,9 +50,10 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
|
||||
printf("arg7: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n");
|
||||
printf("optional:\n");
|
||||
printf("arg16: number of warm-up cycles (default 1)\n");
|
||||
printf("arg17: number of iterations (default 10)\n");
|
||||
printf("arg18: memory for rotating buffer (default 0, size in MB)\n");
|
||||
printf("arg16: number of kbatch (default 1)\n");
|
||||
printf("arg17: number of warm-up cycles (default 1)\n");
|
||||
printf("arg18: number of iterations (default 10)\n");
|
||||
printf("arg19: memory for rotating buffer (default 0, size in MB)\n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
@@ -76,11 +77,13 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
|
||||
int n_warmup = 1;
|
||||
int n_iter = 10;
|
||||
uint64_t rotating = 0;
|
||||
if(argc == 19)
|
||||
int KBatch = 1;
|
||||
if(argc == 20)
|
||||
{
|
||||
n_warmup = std::stoi(argv[16]);
|
||||
n_iter = std::stoi(argv[17]);
|
||||
rotating = std::stoull(argv[18]) * 1024 * 1024;
|
||||
KBatch = std::stoi(argv[16]);
|
||||
n_warmup = std::stoi(argv[17]);
|
||||
n_iter = std::stoi(argv[18]);
|
||||
rotating = std::stoull(argv[19]) * 1024 * 1024;
|
||||
}
|
||||
|
||||
using F32 = float;
|
||||
@@ -146,6 +149,7 @@ int profile_gemm_multiply_multiply(int argc, char* argv[])
|
||||
(StrideD0 < 0) ? DefaultStrideD0 : StrideD0,
|
||||
(StrideD1 < 0) ? DefaultStrideD1 : StrideD1,
|
||||
(StrideE < 0) ? DefaultStrideE : StrideE,
|
||||
KBatch,
|
||||
n_warmup,
|
||||
n_iter,
|
||||
rotating);
|
||||
|
||||
Reference in New Issue
Block a user