Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-14 02:02:46 +00:00
Implement batched gemm gemm for RDNA (3 and 4) (#2612)
* Create new copies of existing device struct and gridwise struct for batched_gemm_softmax_gemm and disable the softmax part. Still based on old wmma pipelines. Also copy the example and remove the softmax part from the reference calculation. Works and results match reference except for tiny float errors in problem 2.
* Turn DeviceBatchedGemmGemm_Wmma_CShuffleV3 into a proper DeviceBatchedGemmGemm derived class, with the right argument and invoker functions. Update example to use new definitions.
* Remove unused cross-attention and self-attention kernels, arguments, and invokers. Also remove other unused Argument types.
* Remove masking related code, test unusual sizes in example.
* Remove remaining softmax related code from GridwiseBatchedGemmGemm_wmma_cshuffle_v3 and example.
* Remove code related to numDims, bias, and TensorSpec from Device struct and example.
* Add layout template parameters to device struct
* Move (NPerBlock, LTilePerBlock) device struct template arguments up by two places to match XDL template argument ordering.
* Merge accumulation data types into one type to match XDL device struct.
* Remove NPerWmma template parameter from device struct and just set it equal to LPerWmma. Now device struct template params exactly match those for XDL batched gemm gemm.
* Add support for RCCR layout and test this in example
* Add batched_gemm_gemm_wmma to instance library + profiler, and add gtest just like for xdl.
* Add RCCR instance and additional RCRR instance to library.
* Remove unused permute and alpha related code. Time all tests. Fix B1 strides in argument verification.
* Remove references to G0, G1 in favor of batch, reduce dimensionality of length and stride arrays.
* Managed to replace old wmma gridwise pipeline and blockwise struct with new wmma blockwise pipeline. Some cleanup required but all tests pass.
* Make TransposeC a proper template parameter that gets passed all the way from BlockGemmPipeline_Selector to WmmaGemm so we can use the correct settings for batched gemm gemm as well as regular gemm. Gemm universal tests now pass again.
* Replace old LoopSched and PipelineVer params with BlockwiseGemm pipeline equivalents, and use these in instance factory. The v3 pipeline does not work yet, but v1 works for intrawave and interwave.
* Adapt the A wave descriptor to deal with RDNA4 wmma. This fixes batched gemm gemm functionality on RDNA4.
* Fixed two aspects of the v3 pipeline that were incorrect. First, the blockwise copy operator was invoked once too often in all cases (RunRead and move window), which broke batched gemm gemm when the blockwise pipeline was used multiple times. Second, we should use the main loop (hotloop) for num_k_loop >= 2 instead of num_k_loop >= 3. Now we support any K dimension.
* Remove the num prefetch parameter from the gridwise struct since we don't use it and it doesn't do anything.
* Remove unused non-lds paths.
* Test and update the IsSupportedArgument() and CheckValidity() functions for all layouts + padding modes and various problem sizes.
* Add a lot of instances to the profiler with various blocksizes and pipelines, all verified.
* Add support for BF16: instance library, tests, and examples.
* Add examples for int8 and fp8, had to add type_convert_sp template specializations for the latter.
* Template the library instance lists and add default padding instances.
* Move memory calculations from the kernel to the Argument constructor. Also actually parse and use the user-provided batch strides.
* Actually parse and use user-provided regular strides.
* More refactoring: remove references to multiple dims per dimension, and G0 / G1. Also move XDL-specific test utils out of the generic test util header.
* Small post-rebase-on-develop fix due to bscale-related pipeline changes. All tests rerun + tested bscale and regular gemm.
* Introduce the correct GetCThreadDescriptor function in the blockwise gemm pipelines for the TransposeC=true case. It turns out to be identical for our batched gemm gemm (gemm0) use cases, but could theoretically be different for wmma_gemm instances with smaller-than-4-byte output data size.
* Remove unused NumPrefetch template parameter, we don't need to match the XDL template params one-to-one.
* Implement proper TailNum and HasMainLoop template parameters for the v3 pipeline. Now the Run() function knows at compile time whether there are 1, 2, or more loops in total, and adds or removes sections accordingly. It still uses the blockwise copy operators the correct amount of times.
* Add a print lambda (with env check, file, and function info) to device- and gridwise-level compatibility error messages. Also respect compatibility in the example script.
* RDNA3 does not support fp8.
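The fused operation this commit implements can be illustrated with a small host-side sketch (an illustration only, not the CK device kernel): for each batch g, Gemm0 computes A * B0 and Gemm1 immediately consumes that result against B1, so the intermediate M x N tile never has to round-trip through global memory.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive host reference for the fused operation C[g] = (A[g] * B0[g]) * B1[g],
// with row-major A (G x M x K), B0 (G x K x N), B1 (G x N x O), C (G x M x O).
std::vector<float> batched_gemm_gemm(const std::vector<float>& A,
                                     const std::vector<float>& B0,
                                     const std::vector<float>& B1,
                                     int G, int M, int K, int N, int O)
{
    std::vector<float> C(std::size_t(G) * M * O, 0.f);
    for(int g = 0; g < G; ++g)
        for(int m = 0; m < M; ++m)
            for(int o = 0; o < O; ++o)
            {
                float acc1 = 0.f;
                for(int n = 0; n < N; ++n)
                {
                    // Gemm0: acc0 = A[g, m, :] dot B0[g, :, n]
                    float acc0 = 0.f;
                    for(int k = 0; k < K; ++k)
                        acc0 += A[(std::size_t(g) * M + m) * K + k] *
                                B0[(std::size_t(g) * K + k) * N + n];
                    // Gemm1 consumes Gemm0's output element directly
                    acc1 += acc0 * B1[(std::size_t(g) * N + n) * O + o];
                }
                C[(std::size_t(g) * M + m) * O + o] = acc1;
            }
    return C;
}
```

The device kernel additionally applies the element-wise operators (e.g. the Scale between the two gemms), which are omitted here.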
[ROCm/composable_kernel commit: 7330ec37ee]
Committed via GitHub. Parent: c217c0fa93 · Commit: e27f9a177d
@@ -1,6 +1,12 @@
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)

add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_bf16 batched_gemm_gemm_wmma_cshuffle_v3_bf16.cpp)
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp8 batched_gemm_gemm_wmma_cshuffle_v3_fp8.cpp)
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp16 batched_gemm_gemm_wmma_cshuffle_v3_fp16.cpp)
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_int8 batched_gemm_gemm_wmma_cshuffle_v3_int8.cpp)

if(USE_BITINT_EXTENSION_INT4)
    add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
@@ -0,0 +1,276 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

/*
Gemm + Gemm fused operation. Computes C_g_m_n = (A_g_m_k * B0_g_k_l) * B1_g_l_n
                                     |------------------|
                                             Gemm0
                                     |-----------------------------|
                                                  Gemm1
*/

static constexpr auto PipeSched   = ck::BlockGemmPipelineScheduler::Interwave;
static constexpr auto PipelineVer = ck::BlockGemmPipelineVersion::v1;
static constexpr auto GemmSpec    = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// clang-format off
// #define CK_MHA_USE_RCCR_LAYOUT
#define CK_MHA_USE_WAVE_1
// #define CK_MHA_USE_WAVE_2
// #define CK_MHA_USE_WAVE_4
// #define CK_MHA_USE_WAVE_8
#ifdef CK_MHA_USE_RCCR_LAYOUT
using DeviceMHAFactory =
    std::tuple<
        ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
            Row, Col, Col, Row,
            ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec,
            32,
            // Gemm 0
            16, 64, 64, 64, 64, 8, 8,
            // Gemm 1
            8,
            16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 4, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
            // B0BlockTransfer LK -> K0 L K1
            S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 16, 1, 2>, 8,
            PipeSched, PipelineVer>
    >;
#else
using DeviceMHAFactory =
    std::tuple<
#ifdef CK_MHA_USE_WAVE_1
        // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
        ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
            Row, Col, Row, Row,
            ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec,
            32,
            // Gemm 0
            16, 128, 64, 64, 64, 8, 8,
            // Gemm 1
            8,
            16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 8, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 16, 1, 2>, 8,
            PipeSched, PipelineVer>,
        ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
            Row, Col, Row, Row,
            ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec,
            32,
            // Gemm 0
            16, 64, 64, 64, 64, 8, 8,
            // Gemm 1
            8,
            16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 4, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 16, 1, 2>, 8,
            PipeSched, PipelineVer>
#endif
#ifdef CK_MHA_USE_WAVE_2
        ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
            Row, Col, Row, Row,
            ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec,
            64,
            // Gemm 0
            32, 128, 64, 64, 64, 8, 8,
            // Gemm 1
            8,
            16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 8, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 32, 1, 2>, 8,
            PipeSched, PipelineVer>,
        ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
            Row, Col, Row, Row,
            ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec,
            64,
            // Gemm 0
            32, 64, 64, 64, 64, 8, 8,
            // Gemm 1
            8,
            16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 4, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 32, 1, 2>, 8,
            PipeSched, PipelineVer>
#endif
#ifdef CK_MHA_USE_WAVE_4
        ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
            Row, Col, Row, Row,
            ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec,
            128,
            // Gemm 0
            64, 128, 64, 64, 64, 8, 8,
            // Gemm 1
            8,
            16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 8, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 64, 1, 2>, 8,
            PipeSched, PipelineVer>,
        ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
            Row, Col, Row, Row,
            ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec,
            128,
            // Gemm 0
            64, 64, 64, 64, 64, 8, 8,
            // Gemm 1
            8,
            16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 4, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 64, 1, 2>, 8,
            PipeSched, PipelineVer>
#endif
#ifdef CK_MHA_USE_WAVE_8
        ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
            Row, Col, Row, Row,
            ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec,
            256,
            // Gemm 0
            128, 128, 64, 64, 64, 8, 8,
            // Gemm 1
            8,
            16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 8, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 128, 1, 2>, 8,
            PipeSched, PipelineVer>,
        ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
            Row, Col, Row, Row,
            ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
            AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
            GemmSpec,
            256,
            // Gemm 0
            128, 128, 64, 64, 64, 8, 8,
            // Gemm 1
            8,
            16, 16,
            // Per repeat = wave_m = wave_num, wave_n = 1
            1, 8, 4,
            // ABlockTransfer MK -> K0 M K1
            S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B0BlockTransfer LK -> K0 L K1
            S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
            // B1BlockTransfer NL -> L0 N L1
            S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
            // CShuffleBlockTransfer MN
            1, 1, S<1, 128, 1, 2>, 8,
            PipeSched, PipelineVer>
#endif
    >;
#endif

// clang-format on

// Ref Gemm0
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B0DataType,
                                                                                AccDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                B0ElementOp,
                                                                                Acc0ElementOp>;

// Ref Gemm1
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
                                                                                B1DataType,
                                                                                CDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                B1ElementOp,
                                                                                CElementOp>;

#include "run_batched_gemm_gemm_wmma_cshuffle_v3.inc"

int main(int argc, char* argv[])
{
    bool is_supported = ck::is_gfx11_supported() || ck::is_gfx12_supported();
    if(!is_supported)
    {
        std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
                  << std::endl;
        return 0;
    }
    return run(argc, argv);
}
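The base file above iterates the DeviceMHAFactory tuple of kernel instances at compile time via ck::static_for. A plain C++17 analog of that pattern (OpA/OpB are hypothetical stand-ins of ours, not CK types) looks like this:

```cpp
#include <cassert>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical stand-ins for the device-op types held in the factory tuple.
struct OpA { static std::string name() { return "OpA"; } };
struct OpB { static std::string name() { return "OpB"; } };
using Factory = std::tuple<OpA, OpB>;

// Visit every element type of a tuple at compile time, default-constructing
// each one, the way the example instantiates each DeviceMHAFactory entry.
template <typename Tuple, typename F, std::size_t... Is>
void for_each_type_impl(F&& f, std::index_sequence<Is...>)
{
    (f(std::tuple_element_t<Is, Tuple>{}), ...);
}

template <typename Tuple, typename F>
void for_each_type(F&& f)
{
    for_each_type_impl<Tuple>(std::forward<F>(f),
                              std::make_index_sequence<std::tuple_size_v<Tuple>>{});
}
```

With this in place, the benchmarking body becomes a single generic lambda, e.g. `for_each_type<Factory>([&](auto op) { /* build argument, run, verify */ });`, which is essentially what the run include file does with ck::static_for.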
@@ -0,0 +1,37 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"

using BF16 = ck::bhalf_t;
using F32  = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType        = BF16;
using B0DataType       = BF16;
using B1DataType       = BF16;
using AccDataType      = F32;
using CShuffleDataType = F32;
using CDataType        = BF16;

using AElementOp    = PassThrough;
using B0ElementOp   = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp   = PassThrough;
using CElementOp    = PassThrough;

#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
@@ -0,0 +1,37 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"

using F16 = ck::half_t;
using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType        = F16;
using B0DataType       = F16;
using B1DataType       = F16;
using AccDataType      = F32;
using CShuffleDataType = F32;
using CDataType        = F16;

using AElementOp    = PassThrough;
using B0ElementOp   = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp   = PassThrough;
using CElementOp    = PassThrough;

#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
@@ -0,0 +1,34 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType        = ck::f8_t;
using B0DataType       = ck::f8_t;
using B1DataType       = ck::f8_t;
using AccDataType      = float;
using CShuffleDataType = float;
using CDataType        = ck::f8_t;

using AElementOp    = PassThrough;
using B0ElementOp   = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp   = PassThrough;
using CElementOp    = PassThrough;

#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
@@ -0,0 +1,34 @@
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType        = int8_t;
using B0DataType       = int8_t;
using B1DataType       = int8_t;
using AccDataType      = int32_t;
using CShuffleDataType = int32_t;
using CDataType        = int8_t;

using AElementOp    = PassThrough;
using B0ElementOp   = PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
using B1ElementOp   = PassThrough;
using CElementOp    = PassThrough;

#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
@@ -0,0 +1,304 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

int run(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // GEMM shape for A/B0/B1/C
    // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
    ck::index_t M = 113;
#ifdef CK_MHA_USE_RCCR_LAYOUT
    ck::index_t N = 480; // Must be multiple of 8 even with padding.
#else
    ck::index_t N = 477;
#endif
    ck::index_t K = 200; // Must be multiple of 8 even with padding.
    ck::index_t O = 208; // Must be multiple of 8 even with padding.
    ck::index_t G = 91;  // Batch

    float alpha = 1;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 10)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
        O = std::stoi(argv[7]);
        G = std::stoi(argv[8]);

        alpha = std::stof(argv[9]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 8: M, N, K, O, G\n");
        printf("arg9: scale (alpha)\n");
        exit(0);
    }

    std::vector<ck::index_t> a_g_m_k_lengths{G, M, K};
    std::vector<ck::index_t> a_g_m_k_strides{M * K, K, 1}; // A layout [G, M, K]
    std::vector<ck::index_t> b0_g_n_k_lengths{G, N, K};
    std::vector<ck::index_t> b0_g_n_k_strides{N * K, K, 1}; // B0 layout [G, N, K]
    std::vector<ck::index_t> b1_g_o_n_lengths{G, O, N};
#ifdef CK_MHA_USE_RCCR_LAYOUT
    std::vector<ck::index_t> b1_g_o_n_strides{N * O, N, 1}; // B1 layout [G, O, N]
#else
    std::vector<ck::index_t> b1_g_o_n_strides{N * O, 1, O}; // B1 layout [G, N, O]
#endif
    std::vector<ck::index_t> c_g_m_o_lengths{G, M, O};
    std::vector<ck::index_t> c_g_m_o_strides{M * O, O, 1}; // C layout [G, M, O]

    Tensor<ADataType> a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides);
    Tensor<B0DataType> b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides);
    Tensor<B1DataType> b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides);
    Tensor<CDataType> c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides);
    Tensor<CDataType> c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides);

    std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_n_k: " << b0_g_n_k.mDesc << std::endl;
    std::cout << "b1_g_o_n: " << b1_g_o_n.mDesc << std::endl;
    std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    case 2:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        break;
    case 3:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        break;
    case 4: // A, B0, B1 = 1
        a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 5: // Rand: b0 b1; unit: a
        a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    case 6: // Rand: a b0; unit: b1
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 7: // Rand: a b1; unit: b0
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    case 8: // Rand: a; unit: b0 b1
        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 9: // Rand: b0; unit: a b1
        a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 10: // Rand: b1; unit: a b0
        a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    default:
        a_g_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
        b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
        b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
    }

    DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_g_n_k.mDesc.GetElementSpaceSize());
    DeviceMem b1_device_buf(sizeof(B1DataType) * b1_g_o_n.mDesc.GetElementSpaceSize());
    DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSpaceSize());

    a_device_buf.ToDevice(a_g_m_k.mData.data());
    b0_device_buf.ToDevice(b0_g_n_k.mData.data());
    b1_device_buf.ToDevice(b1_g_o_n.mData.data());

    auto a_element_op    = AElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto acc0_element_op = Acc0ElementOp{alpha};
    auto b1_element_op   = B1ElementOp{};
    auto c_element_op    = CElementOp{};

    // do GEMM
    float best_perf         = .0;
    float best_time         = .0;
    int not_pass            = 0;
    std::string best_kernel = "";
    printf("Verification: %s\n", do_verification ? "ON" : "OFF");

    ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
        const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});

        using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
        auto gemm               = DeviceMHAInstance{};
        auto invoker_ptr        = gemm.MakeInvokerPointer();
        auto argument_ptr =
            gemm.MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                     static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
                                     static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
                                     static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                     M,
                                     N,
                                     K,
                                     O,
                                     G,                   // Batch
                                     a_g_m_k_strides[1],  // StrideA
                                     b0_g_n_k_strides[1], // StrideB0
#ifdef CK_MHA_USE_RCCR_LAYOUT
                                     b1_g_o_n_strides[1], // StrideB1
#else
                                     b1_g_o_n_strides[2], // StrideB1
#endif
                                     c_g_m_o_strides[1],  // StrideC
                                     a_g_m_k_strides[0],  // BatchStrideA
                                     b0_g_n_k_strides[0], // BatchStrideB0
                                     b1_g_o_n_strides[0], // BatchStrideB1
                                     c_g_m_o_strides[0],  // BatchStrideC
                                     a_element_op,
                                     b0_element_op,
                                     acc0_element_op,
                                     b1_element_op,
                                     c_element_op);

        if(!gemm.IsSupportedArgument(argument_ptr.get()))
        {
            std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
            return;
        }

        float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});

        std::size_t flop      = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G;
        std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
                                 sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
                                G;

        float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

        float gb_per_sec = num_btype / 1.E6 / ave_time;

        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s, " << gemm.GetTypeString() << std::endl;
        if(tflops > best_perf)
        {
            best_perf   = tflops;
            best_time   = ave_time * 1000;
            best_kernel = gemm.GetTypeString();
        }
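The performance accounting in the example counts two flops (multiply plus add) per MAC: M*N*K MACs for Gemm0 and M*N*O MACs for Gemm1, per batch. Since the measured time is in milliseconds, dividing the flop count by 1e9 and by the millisecond value yields TFLOPS. A minimal host-side sketch of that arithmetic (the helper name is ours, not from the example):

```cpp
#include <cassert>
#include <cstddef>

// TFLOPS from flop count and kernel time in milliseconds:
// flop / 1e9 / ms == flop / 1e12 / s, i.e. tera-flops per second.
double tflops_from_ms(int G, int M, int N, int K, int O, double ave_time_ms)
{
    const std::size_t flop =
        (std::size_t(M) * N * K * 2 + std::size_t(M) * N * O * 2) * G;
    return static_cast<double>(flop) / 1.e9 / ave_time_ms;
}
```

For example, a single batch of M = N = K = O = 1000 running in 1 ms performs 4e9 flops, i.e. 4 TFLOPS.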
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
|
||||
|
||||
Tensor<B0DataType> b0_g_k_n({G, K, N});
|
||||
Tensor<B1DataType> b1_g_n_o({G, N, O});
|
||||
Tensor<AccDataType> acc0_g_m_n({G, M, N}); // scratch object after gemm0
|
||||
Tensor<ADataType> a1_g_m_n({G, M, N}); // scratch object after conversion
|
||||
|
||||
// permute
|
||||
b0_g_n_k.ForEach(
|
||||
[&](auto& self, auto idx) { b0_g_k_n(idx[0], idx[2], idx[1]) = self(idx); });
|
||||
b1_g_o_n.ForEach(
|
||||
[&](auto& self, auto idx) { b1_g_n_o(idx[0], idx[2], idx[1]) = self(idx); });
|
||||
|
||||
// gemm 0
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
|
||||
|
||||
ref_gemm0_invoker.Run(ref_gemm0_argument);
|
||||
|
||||
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
|
||||
// Passthrough instead of softmax, DOES involve data type conversion.
|
||||
a1_g_m_n(idx) = ck::type_convert<ADataType, AccDataType>(self(idx));
|
||||
});
|
||||
|
||||
// gemm1
|
||||
auto ref_gemm1 = ReferenceGemm1Instance{};
|
||||
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
|
||||
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
|
||||
b1_g_n_o,
|
||||
c_g_m_o_host_result,
|
||||
PassThrough{},
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_gemm1_invoker.Run(ref_gemm1_argument);
|
||||
|
||||
// default absolute error and relative error is 0.001
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
|
||||
// when BF16 is taken, set absolute error and relative error to 0.01
|
||||
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
|
||||
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
rtol = 1e-2;
|
||||
atol = 1e-2;
|
||||
}
|
||||
|
||||
bool this_run_verification = ck::utils::check_err(c_g_m_o_device_result.mData,
|
||||
c_g_m_o_host_result.mData,
|
||||
"Error: Incorrect results!",
|
||||
rtol,
|
||||
atol);
|
||||
printf("Verification: %s, Pass: %s\n",
|
||||
do_verification ? "ON" : "OFF",
|
||||
this_run_verification ? "YES" : "NO");
|
||||
|
||||
if(!this_run_verification)
|
||||
{
|
||||
not_pass = 1;
printf("%d th batched gemm+gemm instance verification failed\n", i.value);
}
}
});
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Problem Size: G: " << G << ", M: " << M << ", N: " << N << ", K: " << K
<< ", O: " << O << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
<< " us" << std::endl;
std::cout << "---------------------------------------------------------------------------------"
"-----------"
<< std::endl;
return not_pass;
}

@@ -27,7 +27,8 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC = false>
constexpr auto BlockGemmPipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
@@ -50,7 +51,8 @@ constexpr auto BlockGemmPipeline_Selector()
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
KPack,
TransposeC>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
@@ -72,7 +74,8 @@ constexpr auto BlockGemmPipeline_Selector()
NPerWmma,
MRepeat,
NRepeat,
KPack>{};
KPack,
TransposeC>{};
}
else
{

@@ -277,6 +277,21 @@ struct BlockwiseGemmWmmaops_pipeline_base
"wrong!");
}

// transposed WMMA output C' = B' * A'
__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
{
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();

constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];

return make_naive_tensor_descriptor_packed(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
}

__host__ __device__ static constexpr auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
{

@@ -31,7 +31,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v1
{
};
@@ -53,7 +54,8 @@ template <index_t BlockSize,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
@@ -72,7 +74,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -90,8 +93,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>

KPack,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -110,7 +113,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
KPack,
TransposeC>;
using Base::I0;

using Base::A_K1;
@@ -329,7 +333,8 @@ template <index_t BlockSize,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
BlockSize,
ADataType,
@@ -348,7 +353,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
NPerWmma,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -366,8 +372,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
NPerWmma,
MRepeat,
NRepeat,
KPack>

KPack,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -386,7 +392,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
KPack,
TransposeC>;
using Base::I0;
using Base::I1;


@@ -31,7 +31,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v3
{
};
@@ -53,7 +54,8 @@ template <index_t BlockSize,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
@@ -72,7 +74,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -90,7 +93,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>
KPack,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -109,7 +113,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
NPerWmma,
MRepeat,
NRepeat,
KPack>;
KPack,
TransposeC>;
using Base::I0;

using Base::A_K1;
@@ -128,6 +133,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
using Base::GetCThreadBuffer;
using Base::
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;

using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;
@@ -145,8 +152,21 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,

__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
if(BlockHasHotloop(num_loop))
{
return TailNumber::Full;
}
else
{
if(num_loop == 1)
{
return TailNumber::Odd;
}
else
{
return TailNumber::Even;
}
}
}
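The `BlockLoopTailNum` change above replaces the unconditional `TailNumber::Full` with a three-way classification that the later pipeline stages key off (main body for at least 3 loops, pre-tail for at least 2, tail always). A minimal host-side sketch, assuming the hot loop exists when `num_loop > 2` (the real predicate is `BlockHasHotloop` in the pipeline base class):

```cpp
enum class TailNumber { Odd, Even, Full };

// Sketch of the tail classification: Full means a main body plus pre-tail plus
// tail will run; Even means pre-tail plus tail; Odd means tail only.
// The num_loop > 2 threshold is an assumption standing in for BlockHasHotloop.
inline TailNumber block_loop_tail_num(int num_loop)
{
    if(num_loop > 2)
    {
        return TailNumber::Full; // main body executes num_loop - 2 iterations
    }
    if(num_loop == 1)
    {
        return TailNumber::Odd; // only the always-performed tail
    }
    return TailNumber::Even; // pre-tail + tail, no main body
}
```

This matches the guards added below: the main `do/while` runs while `i < num_loop - 2`, and the pre-tail is gated on `Even || Full`.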

__device__ static constexpr auto HotLoopScheduler()
@@ -362,12 +382,15 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

// Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
// Global prefetch 2, perform when at least 2 loops exist.
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}

// Initialize C
c_thread_buf.Clear();
@@ -379,7 +402,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,

__builtin_amdgcn_sched_barrier(0);

// main body
// Main body, perform when at least 3 loops exist.
if constexpr(HasMainLoop)
{
index_t i = 0;
@@ -448,10 +471,62 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_barrier(0);

i += 1;
} while(i < (num_loop - 1));
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Full)

// Pre-tail, perform when at least 2 loops exist.
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
{
block_sync_lds();

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

// No RunRead or MoveSrcSliceWindow here, already finished them all!

b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);

static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
});

using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});

block_sync_lds();

LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}

// Tail, always perform.
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {

@@ -0,0 +1,788 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))

__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

const long_index_t a_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
const long_index_t b0_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
const long_index_t b1_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
const long_index_t c_batch_offset =
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx)));

GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
arg.p_a_grid + a_batch_offset,
arg.p_b0_grid + b0_batch_offset,
arg.p_b1_grid + b1_batch_offset,
arg.p_c_grid + c_batch_offset,
p_shared,
arg.a_grid_desc,
arg.b0_grid_desc,
arg.b1_grid_desc,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
arg.a_element_op,
arg.b0_element_op,
arg.acc_element_op,
arg.b1_element_op,
arg.c_element_op,
arg.block_2_ctile_map);
#else
ignore = arg;
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
}
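The kernel above splits the 1-D launch grid evenly across batches and derives each block's batch index `g_idx` from its block id. A host-side sketch of that mapping, with illustrative names (the device code additionally routes the values through `__builtin_amdgcn_readfirstlane` to keep them scalar):

```cpp
#include <cstdint>

// Which batch a block serves and its block id within that batch's tile grid,
// assuming grid_size is an exact multiple of batch_count (as launched above).
struct BatchMapping
{
    int64_t num_blocks_per_batch;
    int64_t g_idx;          // batch index, used to compute per-tensor base offsets
    int64_t local_block_id; // block id within the batch
};

inline BatchMapping map_block_to_batch(int64_t grid_size, int64_t batch_count, int64_t block_id)
{
    const int64_t per_batch = grid_size / batch_count;
    return BatchMapping{per_batch, block_id / per_batch, block_id % per_batch};
}

// Per-batch pointer offset, mirroring ComputeBasePtrOfStridedBatch::Get*BasePtr.
inline int64_t batch_base_offset(int64_t g_idx, int64_t batch_stride)
{
    return g_idx * batch_stride;
}
```

For example, with a grid of 12 blocks over 3 batches, block 7 serves batch 1 as its 4th block.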

// Computes C = A * B0 * B1
// MN = MK * KL * LN
// ^^^^^^ (Acc0)
// ^^^^^^^^^^^ (Acc1)
template <typename ALayout,
typename B0layout,
typename B1Layout,
typename CLayout,
typename ADataType,
typename B0DataType,
typename B1DataType,
typename CDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename B0ElementwiseOperation,
typename AccElementwiseOperation,
typename B1ElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
ck::index_t BlockSize,
ck::index_t MPerBlock,
ck::index_t LPerBlock, // Gemm0NPerBlock
ck::index_t KPerBlock, // Gemm0KPerBlock
ck::index_t NPerBlock, // Gemm1NPerBlock
ck::index_t LTilePerBlock, // Gemm1KPerBlock
ck::index_t AK1,
ck::index_t BK1,
ck::index_t L1, // B1K1
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
ck::index_t MRepeat, // Gemm0/1 MWmmaPerWave or Mrepeat
ck::index_t LRepeat, // Gemm0 NWmmaPerWave or Nrepeat
ck::index_t NRepeat, // Gemm1 NWmmaPerWave or Nrepeat
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
ck::index_t ABlockTransferSrcVectorDim,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t ABlockTransferDstScalarPerVector_K1,
bool ABlockLdsAddExtraM,
typename B0BlockTransferThreadClusterLengths_K0_L_K1,
typename B0BlockTransferThreadClusterArrangeOrder,
typename B0BlockTransferSrcAccessOrder,
ck::index_t B0BlockTransferSrcVectorDim,
ck::index_t B0BlockTransferSrcScalarPerVector,
ck::index_t B0BlockTransferDstScalarPerVector_K1,
bool B0BlockLdsAddExtraL,
typename B1BlockTransferThreadClusterLengths_L0_N_L1,
typename B1BlockTransferThreadClusterArrangeOrder,
typename B1BlockTransferSrcAccessOrder,
ck::index_t B1BlockTransferSrcVectorDim,
ck::index_t B1BlockTransferSrcScalarPerVector,
ck::index_t B1BlockTransferDstScalarPerVector_L1,
bool B1BlockLdsAddExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1>
struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALayout,
B0layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation>
{
using DeviceOp = DeviceBatchedGemmGemm_Wmma_CShuffleV3;

static constexpr auto I0 = Number<0>{};

// To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal
// to LPerWmma (A.k.a Gemm0 NPerWmma).
static constexpr index_t NPerWmma = LPerWmma;

// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
// Transform operator or just not use one at all.
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
Sequence<1, 1, 1, 1, 1>,
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
GemmSpec,
TensorSpecialization::Default, // ASpec
TensorSpecialization::Default, // B0Spec
TensorSpecialization::Default, // B1Spec
TensorSpecialization::Default>; // CSpec

__host__ __device__ static auto
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
const std::array<index_t, 3>& a_g_m_k_strides_vec)
{
return Transform::MakeAGridDescriptor_AK0_M_AK1(
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
Number<AK1>{});
}

__host__ __device__ static auto
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
{
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
Number<BK1>{});
}

__host__ __device__ static auto
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
{
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
Number<L1>{});
}

using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));

struct ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB0,
index_t BatchStrideB1,
index_t BatchStrideC)
: BatchStrideA_(BatchStrideA),
BatchStrideB0_(BatchStrideB0),
BatchStrideB1_(BatchStrideB1),
BatchStrideC_(BatchStrideC)
{
}

__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}

__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
}

__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
}

__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}

private:
index_t BatchStrideA_;
index_t BatchStrideB0_;
index_t BatchStrideB1_;
index_t BatchStrideC_;
};

// GridwiseOp
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
// DataType Family
ADataType,
B0DataType,
AccDataType, // Acc0DataType
B1DataType,
AccDataType, // Acc1DataType
CShuffleDataType,
CDataType,
// ElementwiseOp Family
AElementwiseOperation,
B0ElementwiseOperation,
AccElementwiseOperation,
B1ElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
// InMemory Data Descriptor
AGridDesc,
B0GridDesc,
B1GridDesc,
CGridDesc_M_N,
// Tiling Family
MPerBlock,
LPerBlock,
KPerBlock,
AK1,
BK1,
NPerBlock,
LTilePerBlock,
L1,
MPerWmma,
LPerWmma,
NPerWmma,
MRepeat,
LRepeat,
NRepeat,
// ThreadCluster Family
BlockSize,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
true,
ABlockLdsAddExtraM,
B0BlockTransferThreadClusterLengths_K0_L_K1,
B0BlockTransferThreadClusterArrangeOrder,
B0BlockTransferSrcAccessOrder,
B0BlockTransferSrcVectorDim,
B0BlockTransferSrcScalarPerVector,
B0BlockTransferDstScalarPerVector_K1,
true,
B0BlockLdsAddExtraL,
B1BlockTransferThreadClusterLengths_L0_N_L1,
B1BlockTransferThreadClusterArrangeOrder,
B1BlockTransferSrcAccessOrder,
B1BlockTransferSrcVectorDim,
B1BlockTransferSrcScalarPerVector,
B1BlockTransferDstScalarPerVector_L1,
false,
B1BlockLdsAddExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
Transform::matrix_padder.PadN,
BlkGemmPipeSched,
BlkGemmPipelineVer>;

struct RawArg : public BaseArgument
{
using arr3 = std::array<ck::index_t, 3>;

RawArg(const ADataType* p_a_grid_,
const B0DataType* p_b0_grid_,
const B1DataType* p_b1_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t O_,
index_t Batch,
index_t StrideA,
index_t StrideB0,
index_t StrideB1,
index_t StrideC,
index_t BatchStrideA,
index_t BatchStrideB0,
index_t BatchStrideB1,
index_t BatchStrideC,
AElementwiseOperation a_element_op_,
B0ElementwiseOperation b0_element_op_,
AccElementwiseOperation acc_element_op_,
B1ElementwiseOperation b1_element_op_,
CElementwiseOperation c_element_op_)
: p_a_grid{p_a_grid_},
p_b0_grid{p_b0_grid_},
p_b1_grid{p_b1_grid_},
p_c_grid{p_c_grid_},
M{M_},
N{N_},
K{K_},
O{O_},
batch_count{Batch},
a_element_op{a_element_op_},
b0_element_op{b0_element_op_},
acc_element_op{acc_element_op_},
b1_element_op{b1_element_op_},
c_element_op{c_element_op_},
compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC}
{

a_g_m_k_lengths = arr3{batch_count, M, K};
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]

b0_g_n_k_lengths = arr3{batch_count, N, K};
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]

b1_g_o_n_lengths = arr3{batch_count, O, N};
b1_g_o_n_strides =
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]

c_g_m_o_lengths = arr3{batch_count, M, O};
c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O]

a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
c_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
}
// Pointers
const ADataType* p_a_grid;
const B0DataType* p_b0_grid;
const B1DataType* p_b1_grid;
CDataType* p_c_grid;

// Raw Problem Size
index_t M;
index_t N;
index_t K;
index_t O;
index_t batch_count;

arr3 a_g_m_k_lengths;
arr3 a_g_m_k_strides;
arr3 b0_g_n_k_lengths;
arr3 b0_g_n_k_strides;
arr3 b1_g_o_n_lengths;
arr3 b1_g_o_n_strides;
arr3 c_g_m_o_lengths;
arr3 c_g_m_o_strides;

AElementwiseOperation a_element_op;
B0ElementwiseOperation b0_element_op;
AccElementwiseOperation acc_element_op;
B1ElementwiseOperation b1_element_op;
CElementwiseOperation c_element_op;

// Grid descriptors and other mem calculators
AGridDesc a_grid_desc;
B0GridDesc b0_grid_desc;
B1GridDesc b1_grid_desc;
CGridDesc_M_N c_grid_desc_m_n;
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock;

typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map;

ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
};
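The `RawArg` constructor above builds `[batch, dim0, dim1]` length/stride triples, and only B1's strides depend on its layout: a row-major B1 (N x O) has unit stride along O, a column-major B1 (O x N) has unit stride along N. A host-side sketch of just that selection, where `row_major` stands in for the `is_same_v<B1Layout, RowMajor>` test:

```cpp
#include <array>

using arr3 = std::array<long, 3>;

// Strides for the [batch_count, O, N] view of B1, mirroring the ternary in the
// RawArg constructor: row-major data varies fastest along O, column-major along N.
inline arr3 make_b1_strides(bool row_major, long batch_stride_b1, long stride_b1)
{
    return row_major ? arr3{batch_stride_b1, 1, stride_b1}
                     : arr3{batch_stride_b1, stride_b1, 1};
}
```

A, B0, and C always use `{batch_stride, leading_stride, 1}` because their supported layouts fix the fastest-varying dimension.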

static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
{
// Print lambda with env check and printf()-style formatting.
const char* curFunc = __func__;
|
||||
auto print = [&curFunc](const char* format, ...) -> void {
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
#endif
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
std::vfprintf(stdout, format, args);
|
||||
va_end(args);
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
print("DeviceOp: Arch err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
|
||||
                     std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
        {
            if(ck::is_gfx11_supported())
            {
                print("DeviceOp: gfx 11 does not support fp8\n");
                return false;
            }
        }

        if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
        {
            print("DeviceOp: Acc0 Type err\n");
            return false;
        }

        if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
        {
            print("DeviceOp: A layout must be Row\n");
            return false;
        }

        if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
        {
            print("DeviceOp: B layout must be Column\n");
            return false;
        }

        if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
                       is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
        {
            print("DeviceOp: B1 layout must be Column or Row\n");
            return false;
        }

        if constexpr(!(is_same_v<CLayout, tensor_layout::gemm::RowMajor>))
        {
            print("DeviceOp: C layout must be Row\n");
            return false;
        }

        // Other padding modes have not been tested and do not get checked individually.
        if constexpr(GemmSpec != GemmSpecialization::Default &&
                     GemmSpec != GemmSpecialization::MNKOPadding)
        {
            print("Padding mode must be default or MNKO\n");
            return false;
        }

        // Per-wmma dimensions other than 16 are largely untested.
        if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
        {
            print("M, L, N per Wmma must be 16\n");
            return false;
        }

        if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
                                      arg.b0_grid_desc,
                                      arg.b1_grid_desc,
                                      arg.c_grid_desc_m_n,
                                      arg.block_2_ctile_map))
        {
            return false;
        }

        // Check scalar-per-vector requirement
        const auto a_extent_lowest  = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
        const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
        const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
        const auto c_extent_lowest  = arg.O;

        if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
             b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
            print("DeviceOp: Data Transfer Vector scalar err\n");
            return false;
        }

        // Check vector load/store requirement
        const auto a_stride_lowest =
            ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
        const auto b0_stride_lowest =
            B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
        const auto b1_stride_lowest =
            B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
        const auto c_stride_lowest = arg.c_g_m_o_strides[2];

        if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
        {
            print("DeviceOp: Data Vectorize transfer err\n");
            return false;
        }

        if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
        {
            return false;
        }

        return true;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
    }

    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::RawArg;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
            const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);

            const index_t grid_size = arg.batch_count * M0 * N0;

            auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
                constexpr bool has_loop  = decltype(has_main_k_block_loop)::value;
                constexpr TailNumber tn = tail_number;

                const auto kernel =
                    kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp, GridwiseOp, has_loop, tn>;

                return launch_and_time_kernel(
                    stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
            };

            bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
            TailNumber TailNum     = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);

            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
            {
                if(HasMainKBlockLoop && TailNum == TailNumber::Full)
                {
                    return launch_kernel(std::integral_constant<bool, true>{},
                                         std::integral_constant<TailNumber, TailNumber::Full>{});
                }
                else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
                {
                    return launch_kernel(std::integral_constant<bool, false>{},
                                         std::integral_constant<TailNumber, TailNumber::Full>{});
                }
                else
                {
                    printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
                    return 0.0f;
                }
            }
            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
            {
                if(HasMainKBlockLoop && TailNum == TailNumber::Full)
                {
                    return launch_kernel(std::integral_constant<bool, true>{},
                                         std::integral_constant<TailNumber, TailNumber::Full>{});
                }
                else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
                {
                    return launch_kernel(std::integral_constant<bool, false>{},
                                         std::integral_constant<TailNumber, TailNumber::Even>{});
                }
                else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
                {
                    return launch_kernel(std::integral_constant<bool, false>{},
                                         std::integral_constant<TailNumber, TailNumber::Odd>{});
                }
                else
                {
                    printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
                    return 0.0f;
                }
            }
            else
            {
                printf("Invalid pipeline version!\n");
                return 0.0f;
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b0,
                                                      const void* p_b1,
                                                      void* p_c,
                                                      ck::index_t M,
                                                      ck::index_t N,
                                                      ck::index_t K,
                                                      ck::index_t O,
                                                      ck::index_t Batch,
                                                      ck::index_t StrideA,
                                                      ck::index_t StrideB0,
                                                      ck::index_t StrideB1,
                                                      ck::index_t StrideC,
                                                      ck::index_t BatchStrideA,
                                                      ck::index_t BatchStrideB0,
                                                      ck::index_t BatchStrideB1,
                                                      ck::index_t BatchStrideC,
                                                      AElementwiseOperation a_element_op,
                                                      B0ElementwiseOperation b0_element_op,
                                                      AccElementwiseOperation acc_element_op,
                                                      B1ElementwiseOperation b1_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
                                        static_cast<const B0DataType*>(p_b0),
                                        static_cast<const B1DataType*>(p_b1),
                                        static_cast<CDataType*>(p_c),
                                        M,
                                        N,
                                        K,
                                        O,
                                        Batch,
                                        StrideA,
                                        StrideB0,
                                        StrideB1,
                                        StrideC,
                                        BatchStrideA,
                                        BatchStrideB0,
                                        BatchStrideB1,
                                        BatchStrideC,
                                        a_element_op,
                                        b0_element_op,
                                        acc_element_op,
                                        b1_element_op,
                                        c_element_op);
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    template <typename T>
    static constexpr const char* DataTypeToString()
    {
        if constexpr(std::is_same_v<T, float>)
        {
            return "fp32";
        }
        else if constexpr(std::is_same_v<T, ck::half_t>)
        {
            return "fp16";
        }
        else if constexpr(std::is_same_v<T, ck::bhalf_t>)
        {
            return "bf16";
        }
        else if constexpr(std::is_same_v<T, ck::f8_t>)
        {
            return "fp8";
        }
        else if constexpr(std::is_same_v<T, ck::bf8_t>)
        {
            return "bf8";
        }
        else if constexpr(std::is_same_v<T, int32_t>)
        {
            return "int32";
        }
        else if constexpr(std::is_same_v<T, int8_t>)
        {
            return "int8";
        }
        else if constexpr(std::is_same_v<T, ck::int4_t>)
        {
            return "int4";
        }
        else
        {
            return "unknown";
        }
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};

        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
            {BlockGemmPipelineVersion::v1, "v1"},
            {BlockGemmPipelineVersion::v2, "v2"},
            {BlockGemmPipelineVersion::v3, "v3"},
            {BlockGemmPipelineVersion::v4, "v4"},
            {BlockGemmPipelineVersion::v5, "v5"}};

        // clang-format off
        str << "DeviceBatchedGemmGemm_Wmma_CShuffleV3"
            << "<"
            << ALayout::name[0]
            << B0layout::name[0]
            << B1Layout::name[0]
            << CLayout::name[0] << ", "
            << "A " << DataTypeToString<ADataType>() << ", "
            << "B0 " << DataTypeToString<B0DataType>() << ", "
            << "B1 " << DataTypeToString<B1DataType>() << ", "
            << "C " << DataTypeToString<CDataType>() << ", "
            << "Acc " << DataTypeToString<AccDataType>() << ", "
            << "Cshuf " << DataTypeToString<CShuffleDataType>() << ", "
            << BlockSize << ", "
            << MPerBlock << ", "
            << LPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << LTilePerBlock << ", "
            << L1 << ", "
            << getGemmSpecializationString(GemmSpec)
            << ">"
            << "BlkGemmPipelineScheduler: "
            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
            << "BlkGemmPipelineVersion: "
            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
            << "BlkGemmPipelinePrefetchStages: "
            << GridwiseOp::BlockwiseGemmPipe::PrefetchStages;
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
|
||||
@@ -243,6 +243,30 @@ inline __host__ __device__ constexpr half_t type_convert_sp<half_t, int>(int x)
    return u.fp16;
}

template <>
inline __host__ __device__ constexpr int type_convert_sp<int, f8_t>(f8_t x)
{
    union
    {
        f8_t fp8;
        int int32;
    } u = {x};

    return u.int32;
}

template <>
inline __host__ __device__ constexpr f8_t type_convert_sp<f8_t, int>(int x)
{
    union
    {
        int int32;
        f8_t fp8;
    } u = {x};

    return u.fp8;
}

template <>
inline __host__ __device__ constexpr int type_convert_sp<int, bhalf_t>(bhalf_t x)
{
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -16,6 +16,70 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_BF16
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Row,
                                                      Row,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances);
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Col,
                                                      Row,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances);
#endif // CK_ENABLE_BF16
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Row,
                                                      Row,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances);
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Col,
                                                      Row,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
@@ -46,6 +110,8 @@ void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_i
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
@@ -86,7 +152,46 @@ struct DeviceOperationInstanceFactory<
    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_BF16
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<B0DataType, bhalf_t> &&
                     is_same_v<B1DataType, bhalf_t> && is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                         is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
            {
                add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance(
                    op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                              is_same_v<B1Layout, Col> && is_same_v<CLayout, Row>)
            {
                add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance(
                    op_ptrs);
            }
        }
#endif // CK_ENABLE_BF16
#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                         is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
            {
                add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                    op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                              is_same_v<B1Layout, Col> && is_same_v<CLayout, Row>)
            {
                add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
                    op_ptrs);
            }
        }
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP16
        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
        {
@@ -103,10 +208,11 @@ struct DeviceOperationInstanceFactory<
                    op_ptrs);
            }
        }
#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
        return op_ptrs;
    }
};
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation

@@ -1,5 +1,9 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_batched_gemm_gemm_instance
    device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
    device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance.cpp
    device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
    device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
    device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
    device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)

@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmPadded  = GemmSpecialization::MNKOPadding;

static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;

// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
// dimensions instead!
template <GemmSpecialization GemmSpec,
          BlockGemmPipelineScheduler PipeSched,
          BlockGemmPipelineVersion PipeVer>
using device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances =
    std::tuple<
        // clang-format off
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
// clang-format on
>;

void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Row,
                                                      Row,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances)
{
    // clang-format off
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
    // clang-format on
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmPadded  = GemmSpecialization::MNKOPadding;

static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;

// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
// dimensions instead!
template <GemmSpecialization GemmSpec,
          BlockGemmPipelineScheduler PipeSched,
          BlockGemmPipelineVersion PipeVer>
using device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances =
    std::tuple<
        // clang-format off
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
        // clang-format on
        >;

void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Col,
                                                      Row,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances)
{
    // clang-format off
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
    // clang-format on
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
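As context for what these instances compute: per the gemm0/gemm1 comments in this file, each batch entry performs Acc[g, m, n] = Σ_k A[g, m, k] · B0[g, k, n] followed by C[g, m, o] = Σ_n Acc[g, m, n] · B1[g, n, o]. The following is a hedged, naive host-side sketch of that chained contraction for dense row-major buffers; the function name and flat-vector layout are illustrative assumptions, not part of this library.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive batched gemm-gemm reference (illustrative sketch, not library code):
// gemm0: Acc[g, m, n] = sum_k A[g, m, k] * B0[g, k, n]
// gemm1: C[g, m, o]   = sum_n Acc[g, m, n] * B1[g, n, o]
// All tensors are dense row-major [batch, rows, cols] flattened into vectors.
std::vector<float> batched_gemm_gemm_ref(const std::vector<float>& A,
                                         const std::vector<float>& B0,
                                         const std::vector<float>& B1,
                                         std::size_t G, std::size_t M, std::size_t N,
                                         std::size_t K, std::size_t O)
{
    assert(A.size() == G * M * K && B0.size() == G * K * N && B1.size() == G * N * O);
    std::vector<float> C(G * M * O, 0.f);
    for(std::size_t g = 0; g < G; ++g)
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t o = 0; o < O; ++o)
            {
                float c = 0.f;
                for(std::size_t n = 0; n < N; ++n)
                {
                    float acc = 0.f; // gemm0 element Acc[g, m, n]
                    for(std::size_t k = 0; k < K; ++k)
                        acc += A[(g * M + m) * K + k] * B0[(g * K + k) * N + n];
                    c += acc * B1[(g * N + n) * O + o]; // gemm1 accumulation
                }
                C[(g * M + m) * O + o] = c;
            }
    return C;
}
```

The device kernels fuse both products so Acc never leaves registers/LDS; this sketch materializes one Acc element at a time purely for verification.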
@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmPadded  = GemmSpecialization::MNKOPadding;

static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;

// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
// dimensions instead!
template <GemmSpecialization GemmSpec,
          BlockGemmPipelineScheduler PipeSched,
          BlockGemmPipelineVersion PipeVer>
using device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
    std::
        tuple<
// clang-format off
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
        // clang-format on
        >;

void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Row,
                                                      Row,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances)
{
    // clang-format off
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
    // clang-format on
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,92 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmPadded  = GemmSpecialization::MNKOPadding;

static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;

static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;

// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
// dimensions instead!
template <GemmSpecialization GemmSpec,
          BlockGemmPipelineScheduler PipeSched,
          BlockGemmPipelineVersion PipeVer>
using device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances =
    std::
        tuple<
// clang-format off
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer|
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
        // clang-format on
        >;

void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Col,
                                                      Row,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances)
{
    // clang-format off
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
    // clang-format on
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -220,9 +220,10 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
     }
 
     std::string best_op_name;
     float best_ave_time   = 0;
     float best_tflops     = 0;
     float best_gb_per_sec = 0;
+    int num_supported_instances = 0;
 
     // profile device op instances
     for(auto& op_ptr : op_ptrs)
||||
@@ -255,6 +256,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
 
         if(op_ptr->IsSupportedArgument(argument_ptr.get()))
         {
+            num_supported_instances++;
             std::string op_name = op_ptr->GetTypeString();
 
             float ave_time =
@@ -309,6 +311,8 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
         }
     }
 
+    printf("\033[36mFound %d supported instances\033[0m\n", num_supported_instances);
+
     std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
               << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
 
@@ -41,7 +41,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
 endif()
 if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
     list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
-    list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
     list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
     list(APPEND PROFILER_OPS profile_gemm_add.cpp)
     list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
@@ -98,6 +97,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12
     list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
     list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
     list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
+    list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
 endif()
 
 if(DL_KERNELS)
@@ -155,7 +155,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
     list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
     list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
     list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
-    list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
     list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
     list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
     list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
@@ -219,6 +218,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
     list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
     list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
     list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
+    list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
 endif()
 
 if(DL_KERNELS)
@@ -1,6 +1,14 @@
-add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp)
+add_gtest_executable(test_batched_gemm_gemm_fp16_xdl test_batched_gemm_gemm_fp16_xdl.cpp)
 if(result EQUAL 0)
-    add_custom_target(test_batched_gemm_gemm)
-    target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
-    add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
+    target_link_libraries(test_batched_gemm_gemm_fp16_xdl PRIVATE utility device_batched_gemm_gemm_instance)
 endif()
+
+add_gtest_executable(test_batched_gemm_gemm_bf16_wmma test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp)
+if(result EQUAL 0)
+    target_link_libraries(test_batched_gemm_gemm_bf16_wmma PRIVATE utility device_batched_gemm_gemm_instance)
+endif()
+
+add_gtest_executable(test_batched_gemm_gemm_fp16_wmma test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp)
+if(result EQUAL 0)
+    target_link_libraries(test_batched_gemm_gemm_fp16_wmma PRIVATE utility device_batched_gemm_gemm_instance)
+endif()
@@ -0,0 +1,128 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "test_batched_gemm_gemm_util.hpp"

template <typename Tuple>
class TestBatchedGemmGemmBF16 : public TestBatchedGemmGemm<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    std::tuple<BF16, BF16, BF16, BF16, Row, Col, Row, Row>,
    std::tuple<BF16, BF16, BF16, BF16, Row, Col, Col, Row>
    >;
// clang-format on

TYPED_TEST_SUITE(TestBatchedGemmGemmBF16, KernelTypes);

TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16)
{
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadM)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {136, 128, 32, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadN)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 136, 32, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadK)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 40, 128, 1},
        {128, 128, 136, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadO)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 136, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddM)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {129, 128, 32, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddN)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 129, 32, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddK)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 33, 128, 1},
        {128, 128, 129, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddO)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 129, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmBF16, DISABLED_Bench_BF16)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {256, 256, 64, 64, 768},
        {256, 256, 128, 128, 768},
        {512, 512, 64, 64, 768},
        {512, 512, 128, 128, 768},
        {1024, 1024, 64, 64, 768},
        {1024, 1024, 128, 128, 768},
        {2048, 2048, 64, 64, 768},
        {2048, 2048, 128, 128, 768},
        {4096, 4096, 64, 64, 768},
        {4096, 4096, 128, 128, 768},
    };
    this->bench_  = true;
    this->verify_ = false;
    this->Run();
}
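The odd-O expectation noted above (a row-major B1 cannot handle odd O) follows from vectorized global-memory access: with B1 row-major, O is the contiguous dimension, so it must divide evenly by the per-thread vector width. A minimal sketch of that kind of divisibility check, with a hypothetical function name (the real check lives in each device struct's `IsSupportedArgument`):

```cpp
#include <cassert>

// Hypothetical sketch of a vector-access support check: the contiguous
// (fastest-moving) dimension of a tensor must be a multiple of the
// per-thread scalar-per-vector width used by the block transfer.
bool supports_vector_access(int contiguous_len, int scalar_per_vector)
{
    return contiguous_len % scalar_per_vector == 0;
}
```

With the scalar-per-vector of 8 used by these instances, O = 136 passes while O = 129 is rejected, which is exactly why the padded-O test uses 136 and the odd-O test expects no supporting instance for row-major B1.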
@@ -0,0 +1,128 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "test_batched_gemm_gemm_util.hpp"

template <typename Tuple>
class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>,
    std::tuple<F16, F16, F16, F16, Row, Col, Col, Row>
    >;
// clang-format on

TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16)
{
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {136, 128, 32, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 136, 32, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 40, 128, 1},
        {128, 128, 136, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 136, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {129, 128, 32, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 129, 32, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 33, 128, 1},
        {128, 128, 129, 128, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 129, 1},
    };
    this->bench_  = true;
    this->verify_ = true;
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {256, 256, 64, 64, 768},
        {256, 256, 128, 128, 768},
        {512, 512, 64, 64, 768},
        {512, 512, 128, 128, 768},
        {1024, 1024, 64, 64, 768},
        {1024, 1024, 128, 128, 768},
        {2048, 2048, 64, 64, 768},
        {2048, 2048, 128, 128, 768},
        {4096, 4096, 64, 64, 768},
        {4096, 4096, 128, 128, 768},
    };
    this->bench_  = true;
    this->verify_ = false;
    this->Run();
}
@@ -1,8 +1,126 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
#include "test_batched_gemm_gemm_util.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"

template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using ALayout  = Row;
    using B0Layout = Col;
    using B1Layout = Row;
    using CLayout  = Row;

    using ADataType        = F16;
    using B0DataType       = F16;
    using B1DataType       = F16;
    using AccDataType      = float;
    using CShuffleDataType = float;
    using CDataType        = F16;

    using AElementOp    = PassThrough;
    using B0ElementOp   = PassThrough;
    using Acc0ElementOp = PassThrough;
    using B1ElementOp   = PassThrough;
    using CElementOp    = PassThrough;

    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;

    // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;

    using DeviceGemmGemmInstance =
        ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
            ALayout,
            B0Layout,
            B1Layout,
            CLayout,
            ADataType,
            B0DataType,
            B1DataType,
            CDataType,
            AccDataType,
            CShuffleDataType,
            AElementOp,
            B0ElementOp,
            Acc0ElementOp,
            B1ElementOp,
            CElementOp,
            GemmSpec,
            1,           // NumGemmKPrefetchStage
            256,         // BlockSize
            128,         // MPerBlock
            128,         // NPerBlock
            32,          // KPerBlock
            128,         // Gemm1NPerBlock
            32,          // Gemm1KPerBlock
            8,           // AK1
            8,           // BK1
            2,           // B1K1
            32,          // MPerXDL
            32,          // NPerXDL
            1,           // MXdlPerWave
            4,           // NXdlPerWave
            4,           // Gemm1NXdlPerWave
            S<4, 64, 1>, // ABlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<4, 64, 1>, // BBlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<8, 32, 1>, // B1BlockTransfer
            S<0, 2, 1>,
            S<0, 2, 1>,
            1,
            4,
            2,
            false,
            1,              // CShuffleMXdlPerWavePerShuffle
            2,              // CShuffleNXdlPerWavePerShuffle
            S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8>;             // CShuffleBlockTransferScalarPerVector_NPerBlock

    bool IsSupported(int M, int N, int K, int O)
    {
        // Dry run: build an argument with null pointers and zero strides just to
        // query whether this instance supports the given problem sizes.
        auto gemm     = DeviceGemmGemmInstance{};
        auto invoker  = gemm.MakeInvoker();
        auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
                                          static_cast<B0DataType*>(nullptr),
                                          static_cast<B1DataType*>(nullptr),
                                          static_cast<CDataType*>(nullptr),
                                          M,
                                          N,
                                          K,
                                          O,
                                          0,              // BatchCount
                                          0,              // StrideA
                                          0,              // StrideB0
                                          0,              // StrideB1
                                          0,              // StrideC
                                          0,              // BatchStrideA
                                          0,              // BatchStrideB0
                                          0,              // BatchStrideB1
                                          0,              // BatchStrideC
                                          PassThrough{},  // a_element_op
                                          PassThrough{},  // b0_element_op
                                          PassThrough{},  // acc0_element_op
                                          PassThrough{},  // b1_element_op
                                          PassThrough{}); // c_element_op

        return gemm.IsSupportedArgument(argument);
    }
};

template <typename Tuple>
class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>

@@ -1,11 +1,10 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <vector>

#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "profiler/profile_batched_gemm_gemm_impl.hpp"

using ck::tensor_operation::device::GemmSpecialization;

@@ -13,7 +12,8 @@ using ck::tensor_operation::device::GemmSpecialization;
template <ck::index_t N>
using I = ck::Number<N>;

-using F16 = ck::half_t;
+using F16  = ck::half_t;
+using BF16 = ck::bhalf_t;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -70,120 +70,3 @@ struct TestBatchedGemmGemm : public ::testing::Test
        }
    }
};

template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using ALayout  = Row;
    using B0Layout = Col;
    using B1Layout = Row;
    using CLayout  = Row;

    using ADataType        = F16;
    using B0DataType       = F16;
    using B1DataType       = F16;
    using AccDataType      = float;
    using CShuffleDataType = float;
    using CDataType        = F16;

    using AElementOp    = PassThrough;
    using B0ElementOp   = PassThrough;
    using Acc0ElementOp = PassThrough;
    using B1ElementOp   = PassThrough;
    using CElementOp    = PassThrough;

    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;

    // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;

    using DeviceGemmGemmInstance =
        ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
            ALayout,
            B0Layout,
            B1Layout,
            CLayout,
            ADataType,
            B0DataType,
            B1DataType,
            CDataType,
            AccDataType,
            CShuffleDataType,
            AElementOp,
            B0ElementOp,
            Acc0ElementOp,
            B1ElementOp,
            CElementOp,
            GemmSpec,
            1,           // NumGemmKPrefetchStage
            256,         // BlockSize
            128,         // MPerBlock
            128,         // NPerBlock
            32,          // KPerBlock
            128,         // Gemm1NPerBlock
            32,          // Gemm1KPerBlock
            8,           // AK1
            8,           // BK1
            2,           // B1K1
            32,          // MPerXDL
            32,          // NPerXDL
            1,           // MXdlPerWave
            4,           // NXdlPerWave
            4,           // Gemm1NXdlPerWave
            S<4, 64, 1>, // ABlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<4, 64, 1>, // BBlockTransfer
            S<1, 0, 2>,
            S<1, 0, 2>,
            2,
            8,
            8,
            true,
            S<8, 32, 1>, // B1BlockTransfer
            S<0, 2, 1>,
            S<0, 2, 1>,
            1,
            4,
            2,
            false,
            1,              // CShuffleMXdlPerWavePerShuffle
            2,              // CShuffleNXdlPerWavePerShuffle
            S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
            8>;             // CShuffleBlockTransferScalarPerVector_NPerBlock

    bool IsSupported(int M, int N, int K, int O)
    {
        // Dry run: build an argument with null pointers and zero strides just to
        // query whether this instance supports the given problem sizes.
        auto gemm     = DeviceGemmGemmInstance{};
        auto invoker  = gemm.MakeInvoker();
        auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
                                          static_cast<B0DataType*>(nullptr),
                                          static_cast<B1DataType*>(nullptr),
                                          static_cast<CDataType*>(nullptr),
                                          M,
                                          N,
                                          K,
                                          O,
                                          0,              // BatchCount
                                          0,              // StrideA
                                          0,              // StrideB0
                                          0,              // StrideB1
                                          0,              // StrideC
                                          0,              // BatchStrideA
                                          0,              // BatchStrideB0
                                          0,              // BatchStrideB1
                                          0,              // BatchStrideC
                                          PassThrough{},  // a_element_op
                                          PassThrough{},  // b0_element_op
                                          PassThrough{},  // acc0_element_op
                                          PassThrough{},  // b1_element_op
                                          PassThrough{}); // c_element_op

        return gemm.IsSupportedArgument(argument);
    }
};