Implement batched gemm gemm for RDNA (3 and 4) (#2612)

* Create new copies of the existing device struct and gridwise struct for batched_gemm_softmax_gemm and disable the softmax part, still based on the old wmma pipelines. Also copy the example and remove the softmax part from the reference calculation. Works: results match the reference except for tiny float errors in problem 2.

* Turn DeviceBatchedGemmGemm_Wmma_CShuffleV3 into a proper DeviceBatchedGemmGemm derived class, with the right argument and invoker functions. Update example to use new definitions.

* Remove unused cross-attention and self-attention kernels, arguments, and invokers. Also remove other unused Argument types.

* Remove masking related code, test unusual sizes in example.

* Remove remaining softmax related code from GridwiseBatchedGemmGemm_wmma_cshuffle_v3 and example.

* Remove code related to numDims, bias, and TensorSpec from Device struct and example.

* Add layout template parameters to device struct.

* Move (NPerBlock, LTilePerBlock) device struct template arguments up by two places to match XDL template argument ordering.

* Merge accumulation data types into one type to match XDL device struct.

* Remove NPerWmma template parameter from device struct and just set it equal to LPerWmma. Now device struct template params exactly match those for XDL batched gemm gemm.

* Add support for RCCR layout and test this in example.

* Add batched_gemm_gemm_wmma to instance library + profiler, and add gtest just like for xdl.

* Add RCCR instance and additional RCRR instance to library.

* Remove unused permute and alpha related code. Time all tests. Fix B1 strides in argument verification.

* Remove references to G0, G1 in favor of batch, reduce dimensionality of length and stride arrays.

* Replace the old wmma gridwise pipeline and blockwise struct with the new wmma blockwise pipeline. Some cleanup remains, but all tests pass.

* Make TransposeC a proper template parameter that gets passed all the way from BlockGemmPipeline_Selector to WmmaGemm, so we can use the correct settings for batched gemm gemm as well as regular gemm. Gemm universal tests now pass again.

* Replace old LoopSched and PipelineVer params with BlockwiseGemm pipeline equivalents, and use these in instance factory. The v3 pipeline does not work yet, but v1 works for intrawave and interwave.

* Adapt the A wave descriptor to deal with RDNA4 wmma. This fixes batched gemm gemm functionality on RDNA4.

* Fixed two aspects of the v3 pipeline that were incorrect. First, the blockwise copy operator was invoked one time too many in all cases (RunRead and move window), which broke batched gemm gemm when the blockwise pipeline was used multiple times. Second, the main loop (hotloop) should be used for num_k_loop >= 2 instead of num_k_loop >= 3. Now any K dimension is supported.

* Remove the num prefetch parameter from the gridwise struct since we don't use it and it doesn't do anything.

* Remove unused non-lds paths.

* Test and update the IsSupportedArgument() and CheckValidity() functions for all layouts + padding modes and various problem sizes.

* Add many instances with various block sizes and pipelines to the profiler, all verified.

* Add support for BF16: instance library, tests, and examples.

* Add examples for int8 and fp8; the latter required adding type_convert_sp template specializations.

* Template the library instance lists and add default padding instances.

* Move memory calculations from the kernel to the Argument constructor. Also actually parse and use the user-provided batch strides.

* Actually parse and use user-provided regular strides.

* More refactoring: remove references to multiple dims per dimension and to g0/g1. Also move xdl-specific test utils out of the generic test util header.

* Small post-rebase-on-develop fix due to bscale-related pipeline changes. All tests rerun + tested bscale and regular gemm.

* Introduce the correct GetCThreadDescriptor function in the blockwise gemm pipelines for the TransposeC=true case. It turns out to be identical for our batched gemm gemm (gemm0) use cases, but could theoretically differ for wmma_gemm instances with smaller-than-4-byte output data size.

* Remove unused NumPrefetch template parameter, we don't need to match the XDL template params one-to-one.

* Implement proper TailNum and HasMainLoop template parameters for the v3 pipeline. Now the Run() function knows at compile time whether there are 1, 2, or more loops in total, and adds or removes sections accordingly. It still invokes the blockwise copy operators the correct number of times.

* Add a print lambda, guarded by an env check and including file and function info, to the device- and gridwise-level compatibility error messages. Also respect compatibility in the example script.

* RDNA3 does not support fp8.

[ROCm/composable_kernel commit: 7330ec37ee]
Author: Kiefer van Teutem
Date: 2025-09-04 23:10:24 +02:00
Committed via GitHub
Parent: 80cc2b7bca
Commit: ecc4a470ec
27 changed files with 3679 additions and 165 deletions


@@ -1,6 +1,12 @@
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_bf16 batched_gemm_gemm_wmma_cshuffle_v3_bf16.cpp)
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp8 batched_gemm_gemm_wmma_cshuffle_v3_fp8.cpp)
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp16 batched_gemm_gemm_wmma_cshuffle_v3_fp16.cpp)
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_int8 batched_gemm_gemm_wmma_cshuffle_v3_int8.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)


@@ -0,0 +1,276 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
/*
Gemm + Gemm fused operation. Computes C_g_m_n = (A_g_m_k * B0_g_k_l) * B1_g_l_n
                                                |------------------|
                                                        Gemm0
                                                |-----------------------------|
                                                             Gemm1
*/
static constexpr auto PipeSched = ck::BlockGemmPipelineScheduler::Interwave;
static constexpr auto PipelineVer = ck::BlockGemmPipelineVersion::v1;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// clang-format off
// #define CK_MHA_USE_RCCR_LAYOUT
#define CK_MHA_USE_WAVE_1
// #define CK_MHA_USE_WAVE_2
// #define CK_MHA_USE_WAVE_4
// #define CK_MHA_USE_WAVE_8
#ifdef CK_MHA_USE_RCCR_LAYOUT
using DeviceMHAFactory =
std::tuple<
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Col, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
32,
// Gemm 0
16, 64, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
PipeSched, PipelineVer>
>;
#else
using DeviceMHAFactory =
std::tuple<
#ifdef CK_MHA_USE_WAVE_1
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
32,
// Gemm 0
16, 128, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
PipeSched, PipelineVer>,
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
32,
// Gemm 0
16, 64, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 16, 1, 2>, 8,
PipeSched, PipelineVer>
#endif
#ifdef CK_MHA_USE_WAVE_2
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
64,
// Gemm 0
32, 128, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
PipeSched, PipelineVer>,
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
64,
// Gemm 0
32, 64, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 32, 1, 2>, 8,
PipeSched, PipelineVer>
#endif
#ifdef CK_MHA_USE_WAVE_4
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
128,
// Gemm 0
64, 128, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
PipeSched, PipelineVer>,
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
128,
// Gemm 0
64, 64, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 4, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 64, 1, 2>, 8,
PipeSched, PipelineVer>
#endif
#ifdef CK_MHA_USE_WAVE_8
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
256,
// Gemm 0
128, 128, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
PipeSched, PipelineVer>,
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
Row, Col, Row, Row,
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
GemmSpec,
256,
// Gemm 0
128, 128, 64, 64, 64, 8, 8,
// Gemm 1
8,
16, 16,
// Per repeat = wave_m = wave_num, wave_n = 1
1, 8, 4,
// ABlockTransfer MK -> K0 M K1
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B0BlockTransfer LK -> K0 L K1
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
// B1BlockTransfer NL -> L0 N L1
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
// CShuffleBlockTransfer MN
1, 1, S<1, 128, 1, 2>, 8,
PipeSched, PipelineVer>
#endif
>;
#endif
// clang-format on
// Ref Gemm0
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B0DataType,
AccDataType,
AccDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp>;
// Ref Gemm1
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
B1DataType,
CDataType,
AccDataType,
AElementOp,
B1ElementOp,
CElementOp>;
#include "run_batched_gemm_gemm_wmma_cshuffle_v3.inc"
int main(int argc, char* argv[])
{
bool is_supported = ck::is_gfx11_supported() || ck::is_gfx12_supported();
if(!is_supported)
{
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
<< std::endl;
return 0;
}
return run(argc, argv);
}


@@ -0,0 +1,37 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
using BF16 = ck::bhalf_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = BF16;
using B0DataType = BF16;
using B1DataType = BF16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = BF16;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"


@@ -0,0 +1,37 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F16;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"


@@ -0,0 +1,34 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = ck::f8_t;
using B0DataType = ck::f8_t;
using B1DataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = ck::f8_t;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"


@@ -0,0 +1,34 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = int8_t;
using B0DataType = int8_t;
using B1DataType = int8_t;
using AccDataType = int32_t;
using CShuffleDataType = int32_t;
using CDataType = int8_t;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"


@@ -0,0 +1,304 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
int run(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck::index_t M = 113;
#ifdef CK_MHA_USE_RCCR_LAYOUT
ck::index_t N = 480; // Must be multiple of 8 even with padding.
#else
ck::index_t N = 477;
#endif
ck::index_t K = 200; // Must be multiple of 8 even with padding.
ck::index_t O = 208; // Must be multiple of 8 even with padding.
ck::index_t G = 91; // Batch
float alpha = 1;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
O = std::stoi(argv[7]);
G = std::stoi(argv[8]);
alpha = std::stof(argv[9]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 8: M, N, K, O, G\n");
printf("arg9: scale (alpha)\n");
exit(0);
}
std::vector<ck::index_t> a_g_m_k_lengths{G, M, K};
std::vector<ck::index_t> a_g_m_k_strides{M * K, K, 1}; // A layout [G, M, K]
std::vector<ck::index_t> b0_g_n_k_lengths{G, N, K};
std::vector<ck::index_t> b0_g_n_k_strides{N * K, K, 1}; // B0 layout [G, N, K]
std::vector<ck::index_t> b1_g_o_n_lengths{G, O, N};
#ifdef CK_MHA_USE_RCCR_LAYOUT
std::vector<ck::index_t> b1_g_o_n_strides{N * O, N, 1}; // B1 layout [G, O, N]
#else
std::vector<ck::index_t> b1_g_o_n_strides{N * O, 1, O}; // B1 layout [G, N, O]
#endif
std::vector<ck::index_t> c_g_m_o_lengths{G, M, O};
std::vector<ck::index_t> c_g_m_o_strides{M * O, O, 1}; // C layout [G, M, O]
Tensor<ADataType> a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides);
Tensor<B0DataType> b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides);
Tensor<B1DataType> b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides);
Tensor<CDataType> c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides);
Tensor<CDataType> c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides);
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
std::cout << "b0_g_n_k: " << b0_g_n_k.mDesc << std::endl;
std::cout << "b1_g_o_n: " << b1_g_o_n.mDesc << std::endl;
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 2:
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_g_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_g_o_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
break;
case 3:
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
break;
case 4: // A, B0, B1 1
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 5: // Rand: b1 b0; unit: a
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 6: // Rand: a b0 ; unit: B1
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 7: // Rand: a b1 ; unit: b0
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
case 8: // Rand: a ; unit: b0 b1
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 9: // Rand: b0 ; unit: a b1
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 10: // Rand: b1 ; unit: a b0
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
break;
default:
a_g_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
}
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_g_n_k.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_g_o_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_g_m_k.mData.data());
b0_device_buf.ToDevice(b0_g_n_k.mData.data());
b1_device_buf.ToDevice(b1_g_o_n.mData.data());
auto a_element_op = AElementOp{};
auto b0_element_op = B0ElementOp{};
auto acc0_element_op = Acc0ElementOp{alpha};
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
float best_perf = .0;
float best_time = .0;
int not_pass = 0;
std::string best_kernel = "";
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
auto gemm = DeviceMHAInstance{};
auto invoker_ptr = gemm.MakeInvokerPointer();
auto argument_ptr =
gemm.MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
G, // Batch,
a_g_m_k_strides[1], // StrideA,
b0_g_n_k_strides[1], // StrideB0,
#ifdef CK_MHA_USE_RCCR_LAYOUT
b1_g_o_n_strides[1], // StrideB1,
#else
b1_g_o_n_strides[2], // StrideB1,
#endif
c_g_m_o_strides[1], // StrideC,
a_g_m_k_strides[0], // BatchStrideA
b0_g_n_k_strides[0], // BatchStrideB0
b1_g_o_n_strides[0], // BatchStrideB1
c_g_m_o_strides[0], // BatchStrideC
a_element_op,
b0_element_op,
acc0_element_op,
b1_element_op,
c_element_op);
if(!gemm.IsSupportedArgument(argument_ptr.get()))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return;
}
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
G;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; // GFLOP per ms == TFLOPS
float gb_per_sec = num_btype / 1.E6 / ave_time;            // MB per ms == GB/s
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm.GetTypeString() << std::endl;
if(tflops > best_perf)
{
best_perf = tflops;
best_time = ave_time * 1000;
best_kernel = gemm.GetTypeString();
}
if(do_verification)
{
c_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
Tensor<B0DataType> b0_g_k_n({G, K, N});
Tensor<B1DataType> b1_g_n_o({G, N, O});
Tensor<AccDataType> acc0_g_m_n({G, M, N}); // scratch object after gemm0
Tensor<ADataType> a1_g_m_n({G, M, N}); // scratch object after conversion
// permute
b0_g_n_k.ForEach(
[&](auto& self, auto idx) { b0_g_k_n(idx[0], idx[2], idx[1]) = self(idx); });
b1_g_o_n.ForEach(
[&](auto& self, auto idx) { b1_g_n_o(idx[0], idx[2], idx[1]) = self(idx); });
// gemm 0
auto ref_gemm0 = ReferenceGemm0Instance{};
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
ref_gemm0_invoker.Run(ref_gemm0_argument);
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
// Passthrough instead of softmax, DOES involve data type conversion.
a1_g_m_n(idx) = ck::type_convert<ADataType, AccDataType>(self(idx));
});
// gemm1
auto ref_gemm1 = ReferenceGemm1Instance{};
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
b1_g_n_o,
c_g_m_o_host_result,
PassThrough{},
b1_element_op,
c_element_op);
ref_gemm1_invoker.Run(ref_gemm1_argument);
// default relative and absolute error tolerances are 0.001
double rtol = 1e-3;
double atol = 1e-3;
// for BF16, loosen both tolerances to 0.01
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
{
rtol = 1e-2;
atol = 1e-2;
}
bool this_run_verification = ck::utils::check_err(c_g_m_o_device_result.mData,
c_g_m_o_host_result.mData,
"Error: Incorrect results!",
rtol,
atol);
printf("Verification: %s\n", this_run_verification ? "PASS" : "FAIL");
if(!this_run_verification)
{
not_pass = 1;
printf("MHA instance %d failed verification\n", i.value);
}
}
});
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: G: " << G << ", M: " << M << ", N: " << N << ", K: " << K
<< ", O: " << O << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
<< " us" << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
return not_pass;
}