Files
composable_kernel/profiler/src/profile_gemm_universal_streamk.cpp
Enrico Degregori b740380906 Wmma support for multiple Ds based GEMMs (#2613)
* Fixed cmake errors related to  gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8"

* Fixed cmake build errors related to test_fp8

* Updates to support mixed precision


(cherry picked from commit e65d71180393e7b66169c56565a6bac740427de6)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Adding support for RRR, F8xF16xF16 gemm_universal_wmma - wip


(cherry picked from commit f8c06322df0abcbd5945a56cdf5bffe56480f9f0)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Added support for F8xF16xF16 to gemm_wmma_universal


(cherry picked from commit 15c851de6daa513a12c2e3af299bab0176175fb5)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Added support for F16xF8xF16 to gemm_wmma_universal

* Added support for BF16xI4xBF16 to gemm_wmma_universal


(cherry picked from commit c6a4a69d2d43d59bae8bdabfae80d648646f217e)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Added support for F16xI4xF16 to gemm_wmma_universal

* Fixed IsSupportedArgument to check ComputeTypeA, ComputeTypeB instead of ADataType, BDataType

* Added missing test class for FP16_KM_NK

* Pre-commit hooks fixes

* Added padding instances for f16xf16xf16

* Fixed cmake errors related to  gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8"


(cherry picked from commit 5bdc993dbf)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Fixed cmake build errors related to test_fp8


(cherry picked from commit 12176616b6)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Ammending changes for adding support for padding instances for f16xf16xf16

* Fixes for padding instances for f16xf16xf16

* Added padding instances for bf16xbf16, f8xf8

* Added packed instances for bf16xi4xbf16

* Added padding instances for f8xf16xf16

* Added padding instances for f16xf8xf16, f16xi4xf16

* Fixed typos for bf16xbf16xbf16 padding instances

* Fixed typos for padded instances

* Added tests for fp16, KM_KN and KM_NK

* Padding not supported for when BDataType is pk_i4_t. Added fix for correct check and removed padding instances.

* Fixed typos

* Updated the set of tests for FP16

* Updated the set of tests for FP16

* Fix typo

* Moved f16xi4 test under the correct data layout group

* example for gemm_universal_bf16

* Adding examples for gemm_wmma instances

* Added the  missing parameters

* Fixed review comments and added executable to cmakeLists

* Fixing clang format

* Fixing build erros

* Fixed compilation failure.

* Modified some code as per gemm_universal_examples

* Fixed the gemm specialization error

* Fixed the build errors.

* Fix strides of a/b_thread_desc

The descriptors are larger than needed (even though the compiler don't alloc registers for unused values).

* Load in M/NRepeat dims with thread copy's slice instead of a loop

* Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation

* Implement Intrawave and Interwave variants of pipeline v1

* Add instances for Interwave and Intrawave v1

* Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0

* Remove instances that are too slow (mostly because of register spilling)

* Add a workaround for fp8/bf8->f32 packed conversion issue

* Add instances for Interwave and Intrawave v1

* Enable profiling of mixed precision with f8 and int4 on WMMA

* Fix segfault in profiler when B is pk_i4_t

b_device_buf's size in bytes is larger than b_k_n_permute so b_device_buf.ToDevice reads out-of-bounds.

* Remove instances that are too slow (mostly because of register spilling)

* Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations

* Add test case for bf16_i4

* Add missing Regular tests

* Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS

They take more than 30 seconds

* Fix a bug that fp16_i4 validation passes only with PermuteB

A permutation required by conversion from pk_i4_t to half_t does not
depend on PermuteB, they can be used independently.

* Use PermuteB with f16_i4 in most instances (as xdl)

Some instances use PermuteB = false for checking correctness.
See also the previous commit.

* Fix cache flushing for pk_i4

* Add mixed precision examples

* Disable all tests and instances with f8 on gfx11

Even though f8_f16 and f16_f8 don't require f8 WMMA instructions,
gfx11 still lacks hardware instructions for fast f8->f32 conversion.

* Add FP16 KM_NK and KM_KN test suites for XDL

These tests were added to common .inc for better testing of WMMA instances

* Support multiple D in GridwiseGemm_wmma_cshuffle_v3

DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters.

* Use ThreadGroupTensorSliceTransfer_v7r3

* Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support

* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma

* Implement DeviceGemmMultipleD_Wmma_CShuffleV3

* Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3

* Prepare gemma_add tests for adding wmma

* Add gemm_add_fastgelu instances and test

* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API

ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use
DeviceGemmMultipleDSplitK instances there.

* removed unnecessary ck parts from compilation

* initial gemm_add_multiply instance implementations

* fixed profiler help message for gemm_add_multiply

* improved multiply_add profiler layout help

* fixed template arguments for test instances

* added test for gemm_add_multiply

* Support multiple D in GridwiseGemm_wmma_cshuffle_v3

DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters.

* Use ThreadGroupTensorSliceTransfer_v7r3

* Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support

* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma

* Implement DeviceGemmMultipleD_Wmma_CShuffleV3

* Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3

* Prepare gemma_add tests for adding wmma

* Add gemm_add_fastgelu instances and test

* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API

ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use
DeviceGemmMultipleDSplitK instances there.

* switched to splitK interface

* log print added to splitk benchmarks

* revert main cmake comments

* newline change reverted

* added add_fastgelu instances

* revert unintended change in xdl add_fastgelu

* created gemm_add_add_fastgelu instances

* created fastegelu instances

* added tests for all splitk fastgelus

* Added tests.

* multiply_add instances created

* updates to add_multiply splitk instances

* splitk xdl test fixes

* added wmma multiply_multiply instances

* fixed ONLY_XDL_AND_WMMA_KERNELS tag

* Added gemm_add examples for wmma v1 and v3

* fixed / workarounded i8 instances

* Modified the v3 code to added one fp16 bxdl instance.

* added bf16 xdl instance.

* adding gemm_add wmma_cshuffle and other support


(cherry picked from commit ec447e7f564095ea969eddc39ec77b843aa52976)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* add instances into camkelists


(cherry picked from commit 23bf2d2771c939ea3ca7f493433c55255bffd08e)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* This is work in progress, edited the template parameters in order to build

(cherry picked from commit b4fde8a3314cb44659c4bbda35f1a0133c63dc41)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* temp work saved, changed the BDataType to f16 or bf16 since wmma currently not support non-equal A and B datatype


(cherry picked from commit 22fbd68f1db458ab50780a394ee2544c7a1484d1)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* added datatype and use clang-format-12


(cherry picked from commit ae4e853682ef1bb27784b2f965b4a66b3751ceec)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* Fixing build errors

* Added instances for v3

* Adding instances and executables

* Code update of template parameters modified.

* Renamed file.

* Added tests.

* resolved error tests.

* Fixing build errors

* Updated comments

* removed the changes as per the MR review comment.

* Updated tests.

* fp8 instances - not tested

* Restored the Cmake file that was reverted by mistake during rebase.

* fixed wmma_op test

* Updated comments.

* Updated the template parameter description

* fixed rdna4 instances

* fixed back compatibility on gfx11

* cleanups

* fix ckProfiler

* one more cmake fix

* added fp8 instances

* Updated tests to ad BF16 instances as per review comment

* Added include file and cleaned up(as per review comment)

* Updated and optimized the example code for all types.

* Fixed clang format

* Resolve "Implement `device_gemm_bilinear` for RDNA4"

* test generalization to handle FP16 shuffle better

* added missing changes

* Added bf16 wmma instance for add_relu

* Added f16 wmma instance and corrected bf16 instance errors.

* Added instances to Cmake

* Modified the template parameters to make the instances work.

* Fixed typo in profiler

* Added v3 instances for gemm_add_relu

* addressed core review comments

* Added test for gemm_add_relu wmma instance

* Cleaned up the code.

* Added examples for gemm_add_relu

* Fixing typo to resolve build errors.

* Fixes applied to fix  the precision loss.

* fix billinear test after merge

* Removed the old wmma instances.

* Added wrapper and renamed the wmma_v3 instances

* Updated copyrights and added wrappers.

* Fixes applied according to review comments

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Robin Voetter <robin@streamhpc.com>

* Removed the old wmma instances.

* Updated wrapper for the v3 instances

* removed the old wmma examples

* Renamed the v3 instances

* Deleted the  gtest file added by mistake.

* Updated thge profiler with wrapper

* Fixed test errors.

* Fixed the review comments

* Fixed the if condition MACROS.

* REVERTED THE PROFILER CHANGES

* Revert "REVERTED THE PROFILER CHANGES"

This reverts commit 21cb98546c.

* Revert "Fixed test errors."

This reverts commit 13efcc6fe1.

* Revert "Updated thge profiler with wrapper"

This reverts commit 536f86661d.

* Added missing wrapper instances

* Updated copyrights.

* Fixed typo.

* Fixed copyrights.

* Updated copyrights.

* updated copyrights.

* comments on the atomics workaround

* fixed cmake comment

* Fix bug from merge

* clang-format-18

* Fix compilation error

* Fix linking error

* Fix bug in add and add_relu examples

* Fix error including file (typo)

* Quick fix to compile examples for different targets

* Fix for multi target

* implemented f16 and bf16 instances for gemm_silu

* addressed review comments

* addressed review comments

* Fix clang format

* Fix clang format

---------

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
Co-authored-by: apoorva <apoorva@streamhpc.com>
Co-authored-by: Anton Gorenko <anton@streamhpc.com>
Co-authored-by: Zoltan Lakatos <zoltan.lakatos@streamhpc.com>
Co-authored-by: Cenxuan <cenxuan@streamhpc.com>
Co-authored-by: Robin Voetter <robin@streamhpc.com>
Co-authored-by: Kiefer van Teutem <kiefer.van.teutem@streamhpc.com>
Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
2025-09-05 16:31:08 +02:00

224 lines
8.1 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_universal_streamk_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
KM_KN_MN, // 2
KM_NK_MN, // 3
};
enum struct GemmDataType
{
F32_F32_F32, // 0
F16_F16_F16, // 1
BF16_BF16_BF16, // 2
INT8_INT8_INT8, // 3
F8_F16_F16, // 4
F16_F8_F16, // 5
F16_F16_F16_F8, // 6
F8_F8_BF16, // 7
};
#define OP_NAME "gemm_universal_streamk"
#define OP_DESC "Universal Streamk GEMM"
int profile_gemm_universal_streamk(int argc, char* argv[])
{
if(argc != 16 && argc != 19)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8; 4: f8@f16; 5: f16@f8; 6: f16, "
"comp f8; 7: f8->bf16,)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("arg14: Stream-k select strategy 0: all DP, 1: 1-tile SK, 2: 2-tile SK\n");
printf("arg15: Grid-size, -1 for max persistent kernel occupancy\n");
printf("optional:\n");
printf("arg16: number of warm-up cycles (default 1)\n");
printf("arg17: number of iterations (default 10)\n");
printf("arg18: memory for rotating buffer (default 0, size in MB)\n");
exit(1);
}
int M;
int N;
int StrideA;
int StrideB;
// Analyze the unsupported matrix shapes, switch the M and N number
if(std::stoi(argv[9]) % 8 != 0 && std::stoi(argv[8]) % 8 == 0)
{
M = std::stoi(argv[9]);
StrideA = std::stoi(argv[12]);
N = std::stoi(argv[8]);
StrideB = std::stoi(argv[11]);
}
else
{
M = std::stoi(argv[8]);
StrideA = std::stoi(argv[11]);
N = std::stoi(argv[9]);
StrideB = std::stoi(argv[12]);
}
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const int init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const int K = std::stoi(argv[10]);
const int StrideC = std::stoi(argv[13]);
const int Streamk_sel = std::stoi(argv[14]);
const int Grid_size = std::stoi(argv[15]);
int n_warmup = 1;
int n_iter = 10;
uint64_t rotating = 0;
if(argc == 19)
{
n_warmup = std::stoi(argv[16]);
n_iter = std::stoi(argv[17]);
rotating = std::stoull(argv[18]) * 1024 * 1024;
}
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using F8 = ck::f8_t;
#endif
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_type,
auto b_type,
auto comp_type,
auto acc_type,
auto c_type,
auto a_layout,
auto b_layout,
auto c_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using ComputeDataType = decltype(comp_type);
using AccDataType = decltype(acc_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
bool pass = ck::profiler::profile_gemm_universal_streamk_impl<ADataType,
BDataType,
ComputeDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC,
Streamk_sel,
Grid_size,
n_warmup,
n_iter,
rotating);
return pass ? 0 : 1;
};
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F16{}, F32{}, F32{}, F16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F16{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F16{}, F8{}, F32{}, F32{}, F16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, F8{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F8{}, F16{}, F32{}, F32{}, F16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F8{}, F16{}, F32{}, F32{}, F16{}, Row{}, Col{}, Row{});
}
#endif
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(BF16{}, BF16{}, F32{}, F32{}, BF16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(BF16{}, BF16{}, F32{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_KN_MN)
{
return profile(BF16{}, BF16{}, F32{}, F32{}, BF16{}, Col{}, Row{}, Row{});
}
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::KM_NK_MN)
{
return profile(BF16{}, BF16{}, F32{}, F32{}, BF16{}, Col{}, Col{}, Row{});
}
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{});
}
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
#endif
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_universal_streamk);