mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '7330ec37ee3b8cf2d54630372dfe9e86a893e4f5' into develop
This commit is contained in:
11
Jenkinsfile
vendored
11
Jenkinsfile
vendored
@@ -855,13 +855,20 @@ def run_aiter_tests(Map conf=[:]){
|
||||
}
|
||||
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 45, unit: 'MINUTES'){
|
||||
timeout(time: 2, unit: 'HOURS'){
|
||||
try{
|
||||
sh "rocminfo"
|
||||
sh "python3 --version"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_gemm_a8w8_blockscale.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_mha.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_2stage.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_blockscale.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_ep.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_sorting.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_sorting_mxfp4.py"
|
||||
sh "python3 /home/jenkins/workspace/aiter/op_tests/test_moe_tkw1.py"
|
||||
}
|
||||
catch(e){
|
||||
echo "Throwing error exception while running AITER tests"
|
||||
@@ -906,7 +913,7 @@ def run_pytorch_tests(Map conf=[:]){
|
||||
}
|
||||
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
timeout(time: 45, unit: 'MINUTES'){
|
||||
timeout(time: 2, unit: 'HOURS'){
|
||||
try{
|
||||
sh "rocminfo"
|
||||
sh "python3 --version"
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
|
||||
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_bf16 batched_gemm_gemm_wmma_cshuffle_v3_bf16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp8 batched_gemm_gemm_wmma_cshuffle_v3_fp8.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_fp16 batched_gemm_gemm_wmma_cshuffle_v3_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_wmma_cshuffle_v3_int8 batched_gemm_gemm_wmma_cshuffle_v3_int8.cpp)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
@@ -0,0 +1,276 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/*
|
||||
Gemm + Gemm fused operation. Computes C_g_m_n = (A_g_m_k * B0_g_k_l) * B1_g_l_n
|
||||
|------------------|
|
||||
Gemm0
|
||||
|-----------------------------|
|
||||
Gemm1
|
||||
*/
|
||||
|
||||
static constexpr auto PipeSched = ck::BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto PipelineVer = ck::BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// clang-format off
|
||||
// #define CK_MHA_USE_RCCR_LAYOUT
|
||||
#define CK_MHA_USE_WAVE_1
|
||||
// #define CK_MHA_USE_WAVE_2
|
||||
// #define CK_MHA_USE_WAVE_4
|
||||
// #define CK_MHA_USE_WAVE_8
|
||||
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
using DeviceMHAFactory =
|
||||
std::tuple<
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Col, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
32,
|
||||
// Gemm 0
|
||||
16, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 16, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
>;
|
||||
#else
|
||||
using DeviceMHAFactory =
|
||||
std::tuple<
|
||||
#ifdef CK_MHA_USE_WAVE_1
|
||||
// 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
32,
|
||||
// Gemm 0
|
||||
16, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 16, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
32,
|
||||
// Gemm 0
|
||||
16, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 16, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_2
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
64,
|
||||
// Gemm 0
|
||||
32, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 32, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
64,
|
||||
// Gemm 0
|
||||
32, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 32, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_4
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
128,
|
||||
// Gemm 0
|
||||
64, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 64, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
128,
|
||||
// Gemm 0
|
||||
64, 64, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 4, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 64, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
#ifdef CK_MHA_USE_WAVE_8
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
256,
|
||||
// Gemm 0
|
||||
128, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 128, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>,
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Wmma_CShuffleV3<
|
||||
Row, Col, Row, Row,
|
||||
ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
|
||||
GemmSpec,
|
||||
256,
|
||||
// Gemm 0
|
||||
128, 128, 64, 64, 64, 8, 8,
|
||||
// Gemm 1
|
||||
8,
|
||||
16, 16,
|
||||
// Per repeat = wave_m = wave_num, wave_n = 1
|
||||
1, 8, 4,
|
||||
// ABlockTransfer MK -> K0 M K1
|
||||
S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B0BlockTransfer LK -> K0 L K1
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
|
||||
// B1BlockTransfer NL -> L0 N L1
|
||||
S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false,
|
||||
// CShuffleBlockTransfer MN
|
||||
1, 1, S<1, 128, 1, 2>, 8,
|
||||
PipeSched, PipelineVer>
|
||||
#endif
|
||||
>;
|
||||
#endif
|
||||
|
||||
// clang-format on
|
||||
// Ref Gemm0
|
||||
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B0DataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp>;
|
||||
|
||||
// Ref Gemm1
|
||||
using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_batched_gemm_gemm_wmma_cshuffle_v3.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool is_supported = ck::is_gfx11_supported() || ck::is_gfx12_supported();
|
||||
if(!is_supported)
|
||||
{
|
||||
std::cout << "WARNING: wmma example not supported on the platform " << ck::get_device_name()
|
||||
<< std::endl;
|
||||
return 0;
|
||||
}
|
||||
return run(argc, argv);
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = BF16;
|
||||
using B0DataType = BF16;
|
||||
using B1DataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = BF16;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,37 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = F16;
|
||||
using B0DataType = F16;
|
||||
using B1DataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using CDataType = F16;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,34 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using B0DataType = ck::f8_t;
|
||||
using B1DataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = ck::f8_t;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,34 @@
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = int8_t;
|
||||
using B0DataType = int8_t;
|
||||
using B1DataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CShuffleDataType = int32_t;
|
||||
using CDataType = int8_t;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
#include "batched_gemm_gemm_wmma_cshuffle_v3_base.inc"
|
||||
@@ -0,0 +1,304 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
int run(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape for A/B0/B1/C
|
||||
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
|
||||
ck::index_t M = 113;
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
ck::index_t N = 480; // Must be multiple of 8 even with padding.
|
||||
#else
|
||||
ck::index_t N = 477;
|
||||
#endif
|
||||
ck::index_t K = 200; // Must be multiple of 8 even with padding.
|
||||
ck::index_t O = 208; // Must be multiple of 8 even with padding.
|
||||
ck::index_t G = 91; // Batch
|
||||
|
||||
float alpha = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
O = std::stoi(argv[7]);
|
||||
G = std::stoi(argv[8]);
|
||||
|
||||
alpha = std::stof(argv[9]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 8: M, N, K, O, G\n");
|
||||
printf("arg9: scale (alpha)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
std::vector<ck::index_t> a_g_m_k_lengths{G, M, K};
|
||||
std::vector<ck::index_t> a_g_m_k_strides{M * K, K, 1}; // A layout [G, M, K]
|
||||
std::vector<ck::index_t> b0_g_n_k_lengths{G, N, K};
|
||||
std::vector<ck::index_t> b0_g_n_k_strides{N * K, K, 1}; // B0 layout [G, N, K]
|
||||
std::vector<ck::index_t> b1_g_o_n_lengths{G, O, N};
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
std::vector<ck::index_t> b1_g_o_n_strides{N * O, N, 1}; // B1 layout [G, O, N]
|
||||
#else
|
||||
std::vector<ck::index_t> b1_g_o_n_strides{N * O, 1, O}; // B1 layout [G, N, O]
|
||||
#endif
|
||||
std::vector<ck::index_t> c_g_m_o_lengths{G, M, O};
|
||||
std::vector<ck::index_t> c_g_m_o_strides{M * O, O, 1}; // C layout [G, M, O]
|
||||
|
||||
Tensor<ADataType> a_g_m_k(a_g_m_k_lengths, a_g_m_k_strides);
|
||||
Tensor<B0DataType> b0_g_n_k(b0_g_n_k_lengths, b0_g_n_k_strides);
|
||||
Tensor<B1DataType> b1_g_o_n(b1_g_o_n_lengths, b1_g_o_n_strides);
|
||||
Tensor<CDataType> c_g_m_o_host_result(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
Tensor<CDataType> c_g_m_o_device_result(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_g_n_k: " << b0_g_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_g_o_n: " << b1_g_o_n.mDesc << std::endl;
|
||||
std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
case 2:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
break;
|
||||
case 3:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
break;
|
||||
case 4: // A, B0, B1 1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 5: // Rand: b1 b0; unit: a
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
case 6: // Rand: a b0 ; unit: B1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 7: // Rand: a b1 ; unit: b0
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
case 8: // Rand: a ; unit: b0 b1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 9: // Rand: b0 ; unit: a b1
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 10: // Rand: b1 ; unit: a b0
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
|
||||
break;
|
||||
default:
|
||||
a_g_m_k.GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 2>{});
|
||||
b0_g_n_k.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
|
||||
b1_g_o_n.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_g_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_g_o_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_g_n_k.mData.data());
|
||||
b1_device_buf.ToDevice(b1_g_o_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto acc0_element_op = Acc0ElementOp{alpha};
|
||||
auto b1_element_op = B1ElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
float best_perf = .0;
|
||||
float best_time = .0;
|
||||
int not_pass = 0;
|
||||
std::string best_kernel = "";
|
||||
printf("Verification: %s\n", do_verification ? "ON" : "OFF");
|
||||
|
||||
ck::static_for<0, std::tuple_size_v<DeviceMHAFactory>, 1>{}([&](auto i) -> void {
|
||||
const auto device_mha_instance = std::get<i>(DeviceMHAFactory{});
|
||||
|
||||
using DeviceMHAInstance = ck::remove_cvref_t<decltype(device_mha_instance)>;
|
||||
auto gemm = DeviceMHAInstance{};
|
||||
auto invoker_ptr = gemm.MakeInvokerPointer();
|
||||
auto argument_ptr =
|
||||
gemm.MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B1DataType*>(b1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
G, // Batch,
|
||||
a_g_m_k_strides[1], // StrideA,
|
||||
b0_g_n_k_strides[1], // StrideB0,
|
||||
#ifdef CK_MHA_USE_RCCR_LAYOUT
|
||||
b1_g_o_n_strides[1], // StrideB1,
|
||||
#else
|
||||
b1_g_o_n_strides[2], // StrideB1,
|
||||
#endif
|
||||
c_g_m_o_strides[1], // StrideC,
|
||||
a_g_m_k_strides[0], // BatchStrideA
|
||||
b0_g_n_k_strides[0], // BatchStrideB0
|
||||
b1_g_o_n_strides[0], // BatchStrideB1
|
||||
c_g_m_o_strides[0], // BatchStrideC
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc0_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G;
|
||||
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
|
||||
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
|
||||
G;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
if(tflops > best_perf)
|
||||
{
|
||||
best_perf = tflops;
|
||||
best_time = ave_time * 1000;
|
||||
best_kernel = gemm.GetTypeString();
|
||||
}
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_g_m_o_device_result.mData.data());
|
||||
|
||||
Tensor<B0DataType> b0_g_k_n({G, K, N});
|
||||
Tensor<B1DataType> b1_g_n_o({G, N, O});
|
||||
Tensor<AccDataType> acc0_g_m_n({G, M, N}); // scratch object after gemm0
|
||||
Tensor<ADataType> a1_g_m_n({G, M, N}); // scratch object after conversion
|
||||
|
||||
// permute
|
||||
b0_g_n_k.ForEach(
|
||||
[&](auto& self, auto idx) { b0_g_k_n(idx[0], idx[2], idx[1]) = self(idx); });
|
||||
b1_g_o_n.ForEach(
|
||||
[&](auto& self, auto idx) { b1_g_n_o(idx[0], idx[2], idx[1]) = self(idx); });
|
||||
|
||||
// gemm 0
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op);
|
||||
|
||||
ref_gemm0_invoker.Run(ref_gemm0_argument);
|
||||
|
||||
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
|
||||
// Passthrough instead of softmax, DOES involve data type conversion.
|
||||
a1_g_m_n(idx) = ck::type_convert<ADataType, AccDataType>(self(idx));
|
||||
});
|
||||
|
||||
// gemm1
|
||||
auto ref_gemm1 = ReferenceGemm1Instance{};
|
||||
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
|
||||
auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
|
||||
b1_g_n_o,
|
||||
c_g_m_o_host_result,
|
||||
PassThrough{},
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
|
||||
ref_gemm1_invoker.Run(ref_gemm1_argument);
|
||||
|
||||
// default absolute error and relative error is 0.001
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
|
||||
// when BF16 is taken, set absolute error and relative error to 0.01
|
||||
if(std::is_same_v<ADataType, ck::bhalf_t> && std::is_same_v<B0DataType, ck::bhalf_t> &&
|
||||
std::is_same_v<B1DataType, ck::bhalf_t> && std::is_same_v<CDataType, ck::bhalf_t>)
|
||||
{
|
||||
rtol = 1e-2;
|
||||
atol = 1e-2;
|
||||
}
|
||||
|
||||
bool this_run_verification = ck::utils::check_err(c_g_m_o_device_result.mData,
|
||||
c_g_m_o_host_result.mData,
|
||||
"Error: Incorrect results!",
|
||||
rtol,
|
||||
atol);
|
||||
printf("Verification: %s, Pass: %s\n",
|
||||
do_verification ? "ON" : "OFF",
|
||||
this_run_verification ? "YES" : "NO");
|
||||
|
||||
if(!this_run_verification)
|
||||
{
|
||||
not_pass = 1;
|
||||
printf("%d th MHA instance verification Failed \n", i.value);
|
||||
}
|
||||
}
|
||||
});
|
||||
std::cout << "---------------------------------------------------------------------------------"
|
||||
"-----------"
|
||||
<< std::endl;
|
||||
std::cout << "Problem Size: G: " << G << ", M: " << M << ", N: " << N << ", K: " << K
|
||||
<< ", O: " << O << std::endl;
|
||||
std::cout << "---------------------------------------------------------------------------------"
|
||||
"-----------"
|
||||
<< std::endl;
|
||||
std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time
|
||||
<< " us" << std::endl;
|
||||
std::cout << "---------------------------------------------------------------------------------"
|
||||
"-----------"
|
||||
<< std::endl;
|
||||
return not_pass;
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/library/include
|
||||
${PROJECT_SOURCE_DIR}/example/include
|
||||
)
|
||||
|
||||
add_custom_target(examples)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
#include "mask.hpp"
|
||||
#include "rotary.hpp"
|
||||
#include "utils.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "topk_softmax_api.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#if 0
|
||||
template <typename T>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "rmsnorm2d_fwd.hpp"
|
||||
#include <cstring>
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "add_rmsnorm2d_rdquant_fwd.hpp"
|
||||
#include <cstring>
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename InputDataType>
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "smoothquant.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
|
||||
// different threshold for different dtype
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "moe_sorting_api.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "moe_smoothquant.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include <cstring>
|
||||
#include <set>
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#include <set>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include "fused_moe.hpp"
|
||||
|
||||
// different threshold for different dtype
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include <json_dump.hpp>
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
#include <type_traits>
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
template <typename T>
|
||||
constexpr const char* DataTypeToString()
|
||||
{
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include "elementwise_common.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include "elementwise_common.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_transpose.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include "elementwise_common.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#include "elementwise_common.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
|
||||
#include "batched_transpose_example.hpp"
|
||||
|
||||
#include "json_dump.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
#if 0
|
||||
template <typename T>
|
||||
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
|
||||
|
||||
@@ -27,7 +27,8 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
constexpr auto BlockGemmPipeline_Selector()
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
@@ -50,7 +51,8 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
KPack,
|
||||
TransposeC>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
@@ -72,7 +74,8 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
KPack,
|
||||
TransposeC>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -277,6 +277,21 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
// transposed WMMA output C' = B' * A'
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
|
||||
{
|
||||
constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
|
||||
wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
|
||||
|
||||
constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
|
||||
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
|
||||
// |NThreadPerSubGroup |MAccVgprs
|
||||
make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
|
||||
{
|
||||
|
||||
@@ -31,7 +31,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1
|
||||
{
|
||||
};
|
||||
@@ -53,7 +54,8 @@ template <index_t BlockSize,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -72,7 +74,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -90,8 +93,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
KPack,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -110,7 +113,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
|
||||
using Base::A_K1;
|
||||
@@ -329,7 +333,8 @@ template <index_t BlockSize,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -348,7 +353,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -366,8 +372,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
KPack,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -386,7 +392,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
|
||||
|
||||
@@ -31,7 +31,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v3
|
||||
{
|
||||
};
|
||||
@@ -53,7 +54,8 @@ template <index_t BlockSize,
|
||||
index_t NPerWmma,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
index_t KPack,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -72,7 +74,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -90,7 +93,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
KPack,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -109,7 +113,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
KPack,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
|
||||
using Base::A_K1;
|
||||
@@ -128,6 +133,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::
|
||||
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
|
||||
using Base::
|
||||
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;
|
||||
|
||||
using Base::a_block_desc_k0_m0_m1_m2_k1;
|
||||
using Base::b_block_desc_k0_n0_n1_n2_k1;
|
||||
@@ -145,8 +152,21 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
__host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
if(BlockHasHotloop(num_loop))
|
||||
{
|
||||
return TailNumber::Full;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(num_loop == 1)
|
||||
{
|
||||
return TailNumber::Odd;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TailNumber::Even;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto HotLoopScheduler()
|
||||
@@ -362,12 +382,15 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
// Global prefetch 2, perform when at least 2 loops exist.
|
||||
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
}
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
@@ -379,7 +402,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
// Main body, perform when at least 3 loops exist.
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
@@ -448,10 +471,62 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
|
||||
// Pre-tail, perform when at least 2 loops exist.
|
||||
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// No RunRead or MoveSrcSliceWindow here, already finished them all!
|
||||
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// Tail, always perform.
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
|
||||
@@ -0,0 +1,788 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename DeviceOp, typename GridwiseOp, bool HasMainKBlockLoop, TailNumber TailNum>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
|
||||
|
||||
__shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b0_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t c_batch_offset =
|
||||
__builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
|
||||
|
||||
GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
|
||||
arg.p_a_grid + a_batch_offset,
|
||||
arg.p_b0_grid + b0_batch_offset,
|
||||
arg.p_b1_grid + b1_batch_offset,
|
||||
arg.p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
arg.a_element_op,
|
||||
arg.b0_element_op,
|
||||
arg.acc_element_op,
|
||||
arg.b1_element_op,
|
||||
arg.c_element_op,
|
||||
arg.block_2_ctile_map);
|
||||
#else
|
||||
ignore = arg;
|
||||
#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
}
|
||||
|
||||
// Computes C = A * B0 * B1
|
||||
// MN = MK * KL * LN
|
||||
// ^^^^^^ (Acc0)
|
||||
// ^^^^^^^^^^^ (Acc1)
|
||||
template <typename ALayout,
|
||||
typename B0layout,
|
||||
typename B1Layout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename B0DataType,
|
||||
typename B1DataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename AccElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t LPerBlock, // Gemm0NPerBlock
|
||||
ck::index_t KPerBlock, // Gemm0KPerBlock
|
||||
ck::index_t NPerBlock, // Gemm1NPerBlock
|
||||
ck::index_t LTilePerBlock, // Gemm1KPerBlock
|
||||
ck::index_t AK1,
|
||||
ck::index_t BK1,
|
||||
ck::index_t L1, // B1K1
|
||||
ck::index_t MPerWmma, // Gemm0/1 MPerWmma
|
||||
ck::index_t LPerWmma, // Gemm0/1 NPerWmma
|
||||
ck::index_t MRepeat, // Gemm0/1 MWmmaPerWave or Mrepeat
|
||||
ck::index_t LRepeat, // Gemm0 NWmmaPerWave or Nrepeat
|
||||
ck::index_t NRepeat, // Gemm1 NWmmaPerWave or Nrepeat
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename B0BlockTransferThreadClusterLengths_K0_L_K1,
|
||||
typename B0BlockTransferThreadClusterArrangeOrder,
|
||||
typename B0BlockTransferSrcAccessOrder,
|
||||
ck::index_t B0BlockTransferSrcVectorDim,
|
||||
ck::index_t B0BlockTransferSrcScalarPerVector,
|
||||
ck::index_t B0BlockTransferDstScalarPerVector_K1,
|
||||
bool B0BlockLdsAddExtraL,
|
||||
typename B1BlockTransferThreadClusterLengths_L0_N_L1,
|
||||
typename B1BlockTransferThreadClusterArrangeOrder,
|
||||
typename B1BlockTransferSrcAccessOrder,
|
||||
ck::index_t B1BlockTransferSrcVectorDim,
|
||||
ck::index_t B1BlockTransferSrcScalarPerVector,
|
||||
ck::index_t B1BlockTransferDstScalarPerVector_L1,
|
||||
bool B1BlockLdsAddExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1>
|
||||
struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALayout,
|
||||
B0layout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmGemm_Wmma_CShuffleV3;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
// To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal
|
||||
// to LPerWmma (A.k.a Gemm0 NPerWmma).
|
||||
static constexpr index_t NPerWmma = LPerWmma;
|
||||
|
||||
// TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler
|
||||
// Transform operator or just not use one at all.
|
||||
using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma<
|
||||
Sequence<1, 1, 1, 1, 1>,
|
||||
Sequence<MPerBlock, LPerBlock, KPerBlock, NPerBlock>,
|
||||
GemmSpec,
|
||||
TensorSpecialization::Default, // ASpec
|
||||
TensorSpecialization::Default, // B0Spec
|
||||
TensorSpecialization::Default, // B1Spec
|
||||
TensorSpecialization::Default>; // CSpec
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeAGridDescriptor(const std::array<index_t, 3>& a_g_m_k_lengths_vec,
|
||||
const std::array<index_t, 3>& a_g_m_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeAGridDescriptor_AK0_M_AK1(
|
||||
Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec),
|
||||
Number<AK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB0GridDescriptor(const std::array<index_t, 3>& b0_g_l_k_lengths_vec,
|
||||
const std::array<index_t, 3>& b0_g_l_k_strides_vec)
|
||||
{
|
||||
return Transform::MakeB0GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec),
|
||||
Number<BK1>{});
|
||||
}
|
||||
|
||||
__host__ __device__ static auto
|
||||
MakeB1GridDescriptor(const std::array<index_t, 3>& b1_g_n_l_lengths_vec,
|
||||
const std::array<index_t, 3>& b1_g_n_l_strides_vec)
|
||||
{
|
||||
return Transform::MakeB1GridDescriptor_BK0_N_BK1(
|
||||
Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec),
|
||||
Number<L1>{});
|
||||
}
|
||||
|
||||
using AGridDesc = decltype(MakeAGridDescriptor({}, {}));
|
||||
using B0GridDesc = decltype(MakeB0GridDescriptor({}, {}));
|
||||
using B1GridDesc = decltype(MakeB1GridDescriptor({}, {}));
|
||||
using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideC_(BatchStrideC)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB0_;
|
||||
index_t BatchStrideB1_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
|
||||
// GridwiseOp
|
||||
using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3<
|
||||
// DataType Family
|
||||
ADataType,
|
||||
B0DataType,
|
||||
AccDataType, // Acc0DataType
|
||||
B1DataType,
|
||||
AccDataType, // Acc1DataType
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
// ElementwiseOp Family
|
||||
AElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc,
|
||||
B0GridDesc,
|
||||
B1GridDesc,
|
||||
CGridDesc_M_N,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
LPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
NPerBlock,
|
||||
LTilePerBlock,
|
||||
L1,
|
||||
MPerWmma,
|
||||
LPerWmma,
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
LRepeat,
|
||||
NRepeat,
|
||||
// ThreadCluster Family
|
||||
BlockSize,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
true,
|
||||
ABlockLdsAddExtraM,
|
||||
B0BlockTransferThreadClusterLengths_K0_L_K1,
|
||||
B0BlockTransferThreadClusterArrangeOrder,
|
||||
B0BlockTransferSrcAccessOrder,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferDstScalarPerVector_K1,
|
||||
true,
|
||||
B0BlockLdsAddExtraL,
|
||||
B1BlockTransferThreadClusterLengths_L0_N_L1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferDstScalarPerVector_L1,
|
||||
false,
|
||||
B1BlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
Transform::matrix_padder.PadN,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer>;
|
||||
|
||||
struct RawArg : public BaseArgument
|
||||
{
|
||||
using arr3 = std::array<ck::index_t, 3>;
|
||||
|
||||
RawArg(const ADataType* p_a_grid_,
|
||||
const B0DataType* p_b0_grid_,
|
||||
const B1DataType* p_b1_grid_,
|
||||
CDataType* p_c_grid_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t O_,
|
||||
index_t Batch,
|
||||
index_t StrideA,
|
||||
index_t StrideB0,
|
||||
index_t StrideB1,
|
||||
index_t StrideC,
|
||||
index_t BatchStrideA,
|
||||
index_t BatchStrideB0,
|
||||
index_t BatchStrideB1,
|
||||
index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op_,
|
||||
B0ElementwiseOperation b0_element_op_,
|
||||
AccElementwiseOperation acc_element_op_,
|
||||
B1ElementwiseOperation b1_element_op_,
|
||||
CElementwiseOperation c_element_op_)
|
||||
: p_a_grid{p_a_grid_},
|
||||
p_b0_grid{p_b0_grid_},
|
||||
p_b1_grid{p_b1_grid_},
|
||||
p_c_grid{p_c_grid_},
|
||||
M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
O{O_},
|
||||
batch_count{Batch},
|
||||
a_element_op{a_element_op_},
|
||||
b0_element_op{b0_element_op_},
|
||||
acc_element_op{acc_element_op_},
|
||||
b1_element_op{b1_element_op_},
|
||||
c_element_op{c_element_op_},
|
||||
compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC}
|
||||
{
|
||||
|
||||
a_g_m_k_lengths = arr3{batch_count, M, K};
|
||||
a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K]
|
||||
|
||||
b0_g_n_k_lengths = arr3{batch_count, N, K};
|
||||
b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K]
|
||||
|
||||
b1_g_o_n_lengths = arr3{batch_count, O, N};
|
||||
b1_g_o_n_strides =
|
||||
is_same_v<B1Layout, tensor_layout::gemm::RowMajor>
|
||||
? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O]
|
||||
: arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N]
|
||||
|
||||
c_g_m_o_lengths = arr3{batch_count, M, O};
|
||||
c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O]
|
||||
|
||||
a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides);
|
||||
b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides);
|
||||
b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
|
||||
c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
|
||||
block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
|
||||
}
|
||||
// Pointers
|
||||
const ADataType* p_a_grid;
|
||||
const B0DataType* p_b0_grid;
|
||||
const B1DataType* p_b1_grid;
|
||||
CDataType* p_c_grid;
|
||||
|
||||
// Raw Problem Size
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t O;
|
||||
index_t batch_count;
|
||||
|
||||
arr3 a_g_m_k_lengths;
|
||||
arr3 a_g_m_k_strides;
|
||||
arr3 b0_g_n_k_lengths;
|
||||
arr3 b0_g_n_k_strides;
|
||||
arr3 b1_g_o_n_lengths;
|
||||
arr3 b1_g_o_n_strides;
|
||||
arr3 c_g_m_o_lengths;
|
||||
arr3 c_g_m_o_strides;
|
||||
|
||||
AElementwiseOperation a_element_op;
|
||||
B0ElementwiseOperation b0_element_op;
|
||||
AccElementwiseOperation acc_element_op;
|
||||
B1ElementwiseOperation b1_element_op;
|
||||
CElementwiseOperation c_element_op;
|
||||
|
||||
// Grid descriptors and other mem calculators
|
||||
AGridDesc a_grid_desc;
|
||||
B0GridDesc b0_grid_desc;
|
||||
B1GridDesc b1_grid_desc;
|
||||
CGridDesc_M_N c_grid_desc_m_n;
|
||||
typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
|
||||
typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map;
|
||||
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg)
|
||||
{
|
||||
// Print lambda with env check and printf() style formmating.
|
||||
const char* curFunc = __func__;
|
||||
auto print = [&curFunc](const char* format, ...) -> void {
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wformat-nonliteral"
|
||||
#endif
|
||||
va_list args;
|
||||
va_start(args, format);
|
||||
std::vfprintf(stdout, format, args);
|
||||
va_end(args);
|
||||
#if defined(__clang__)
|
||||
#pragma clang diagnostic pop
|
||||
#endif
|
||||
std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n";
|
||||
}
|
||||
};
|
||||
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
print("DeviceOp: Arch err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
|
||||
std::is_same_v<B0DataType, f8_t> || std::is_same_v<B0DataType, bf8_t> ||
|
||||
std::is_same_v<B1DataType, f8_t> || std::is_same_v<B1DataType, bf8_t>)
|
||||
{
|
||||
if(ck::is_gfx11_supported())
|
||||
{
|
||||
print("DeviceOp: gfx 11 does not support fp8\n");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
print("DeviceOp: Acc0 Type err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<ALayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: A layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B0layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B layout must be Column\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<B1Layout, tensor_layout::gemm::RowMajor> ||
|
||||
is_same_v<B1Layout, tensor_layout::gemm::ColumnMajor>))
|
||||
{
|
||||
print("DeviceOp: B1 layout must be Column or Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_same_v<CLayout, tensor_layout::gemm::RowMajor>))
|
||||
{
|
||||
print("DeviceOp: C layout must be Row\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Other padding modes have not been tested and do not get checked individually.
|
||||
if constexpr(GemmSpec != GemmSpecialization::Default &&
|
||||
GemmSpec != GemmSpecialization::MNKOPadding)
|
||||
{
|
||||
print("Padding mode must be default or MNKO\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Per wmma dimensions not equal to 16 are very untested.
|
||||
if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16)
|
||||
{
|
||||
print("M, L, N per Wmma must be 16\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
|
||||
arg.b0_grid_desc,
|
||||
arg.b1_grid_desc,
|
||||
arg.c_grid_desc_m_n,
|
||||
arg.block_2_ctile_map))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check scalar per vector requirement
|
||||
const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M;
|
||||
const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N;
|
||||
const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O;
|
||||
const auto c_extent_lowest = arg.O;
|
||||
|
||||
if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 &&
|
||||
b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
|
||||
c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
print("DeviceOp: Data Transfer Vector scalar err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check vector load/store requirement
|
||||
const auto a_stride_lowest =
|
||||
ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1];
|
||||
const auto b0_stride_lowest =
|
||||
B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1];
|
||||
const auto b1_stride_lowest =
|
||||
B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1];
|
||||
const auto c_stride_lowest = arg.c_g_m_o_strides[2];
|
||||
|
||||
if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
|
||||
c_stride_lowest == 1))
|
||||
{
|
||||
print("DeviceOp: Data Vectorize transfer err\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const RawArg*>(p_arg));
|
||||
}
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::RawArg;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock);
|
||||
const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock);
|
||||
|
||||
const index_t grid_size = arg.batch_count * M0 * N0;
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) {
|
||||
constexpr bool has_loop = decltype(has_main_k_block_loop)::value;
|
||||
constexpr TailNumber tn = tail_number;
|
||||
|
||||
const auto kernel =
|
||||
kernel_batched_gemm_gemm_wmma_cshuffle_v3<DeviceOp, GridwiseOp, has_loop, tn>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg);
|
||||
};
|
||||
|
||||
bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K);
|
||||
TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K);
|
||||
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(HasMainKBlockLoop && TailNum == TailNumber::Full)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, true>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Full>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Even)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd)
|
||||
{
|
||||
return launch_kernel(std::integral_constant<bool, false>{},
|
||||
std::integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("Invalid pipeline version!\n");
|
||||
return 0.0f;
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b0,
|
||||
const void* p_b1,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t O,
|
||||
ck::index_t Batch,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideB0,
|
||||
ck::index_t StrideB1,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t BatchStrideA,
|
||||
ck::index_t BatchStrideB0,
|
||||
ck::index_t BatchStrideB1,
|
||||
ck::index_t BatchStrideC,
|
||||
AElementwiseOperation a_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
AccElementwiseOperation acc_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<RawArg>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
Batch,
|
||||
StrideA,
|
||||
StrideB0,
|
||||
StrideB1,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB0,
|
||||
BatchStrideB1,
|
||||
BatchStrideC,
|
||||
a_element_op,
|
||||
b0_element_op,
|
||||
acc_element_op,
|
||||
b1_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
    // Factory for a fresh, stateless Invoker instance.
    static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static constexpr const char* DataTypeToString()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, float>)
|
||||
{
|
||||
return "fp32";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck::half_t>)
|
||||
{
|
||||
return "fp16";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck::bhalf_t>)
|
||||
{
|
||||
return "bf16";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck::f8_t>)
|
||||
{
|
||||
return "fp8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck::bf8_t>)
|
||||
{
|
||||
return "bf8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, int32_t>)
|
||||
{
|
||||
return "int32";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, int8_t>)
|
||||
{
|
||||
return "int8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck::int4_t>)
|
||||
{
|
||||
return "int4";
|
||||
}
|
||||
else
|
||||
{
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmGemm_Wmma_CShuffleV3"
|
||||
<< "<"
|
||||
<< ALayout::name[0]
|
||||
<< B0layout::name[0]
|
||||
<< B1Layout::name[0]
|
||||
<< CLayout::name[0] << ", "
|
||||
<< "A " << DataTypeToString<ADataType>() << ", "
|
||||
<< "B0 " << DataTypeToString<B0DataType>() << ", "
|
||||
<< "B1 " << DataTypeToString<B1DataType>() << ", "
|
||||
<< "C " << DataTypeToString<CDataType>() << ", "
|
||||
<< "Acc " << DataTypeToString<AccDataType>() << ", "
|
||||
<< "Cshuf " << DataTypeToString<CShuffleDataType>() << ", "
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< LPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< AK1 << ", "
|
||||
<< BK1 << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< LTilePerBlock << ", "
|
||||
<< L1 << ", "
|
||||
<< getGemmSpecializationString(GemmSpec)
|
||||
<< ">"
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseOp::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -243,6 +243,30 @@ inline __host__ __device__ constexpr half_t type_convert_sp<half_t, int>(int x)
|
||||
return u.fp16;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ constexpr int type_convert_sp<int, f8_t>(f8_t x)
|
||||
{
|
||||
union
|
||||
{
|
||||
f8_t fp8;
|
||||
int int32;
|
||||
} u = {x};
|
||||
|
||||
return u.int32;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ constexpr f8_t type_convert_sp<f8_t, int>(int x)
|
||||
{
|
||||
union
|
||||
{
|
||||
int int32;
|
||||
f8_t fp8;
|
||||
} u = {x};
|
||||
|
||||
return u.fp8;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __host__ __device__ constexpr int type_convert_sp<int, bhalf_t>(bhalf_t x)
|
||||
{
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,6 +16,70 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif // CK_ENABLE_BF16
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
|
||||
Col,
|
||||
Col,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_WMMA
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
|
||||
@@ -46,6 +110,8 @@ void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_i
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_XDL
|
||||
template <typename ALayout,
|
||||
typename B0Layout,
|
||||
typename B1Layout,
|
||||
@@ -86,7 +152,46 @@ struct DeviceOperationInstanceFactory<
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<B0DataType, bhalf_t> &&
|
||||
is_same_v<B1DataType, bhalf_t> && is_same_v<CDataType, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
|
||||
is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
|
||||
is_same_v<B1Layout, Col> && is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_BF16
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
|
||||
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
|
||||
is_same_v<B1Layout, Row> && is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
|
||||
is_same_v<B1Layout, Col> && is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_WMMA
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
|
||||
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
|
||||
{
|
||||
@@ -103,10 +208,11 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_XDL
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
#endif
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_batched_gemm_gemm_instance
|
||||
device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance.cpp
|
||||
device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
|
||||
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmPadded = GemmSpecialization::MNKOPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
|
||||
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
|
||||
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
|
||||
// dimensions instead!
|
||||
template <GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler PipeSched,
|
||||
BlockGemmPipelineVersion PipeVer>
|
||||
using device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances =
|
||||
std::
|
||||
tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
|
||||
Col,
|
||||
Row,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// clang-format off
|
||||
add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
|
||||
add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
|
||||
add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
|
||||
add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
|
||||
add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
|
||||
add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmPadded = GemmSpecialization::MNKOPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
|
||||
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
|
||||
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
|
||||
// dimensions instead!
|
||||
template <GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler PipeSched,
|
||||
BlockGemmPipelineVersion PipeVer>
|
||||
using device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances =
|
||||
std::
|
||||
tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, BF16, BF16, BF16, BF16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Appends every BF16 WMMA CShuffleV3 batched GEMM+GEMM instance for the
// gmk_gnk_gon_gmo layout combination (ALayout=Row, B0Layout=Col, B1Layout=Col,
// CLayout=Row) to `instances`.
//
// Six instance sets are registered: the {GemmDefault, GemmPadded} GEMM
// specializations crossed with the {Intrawave/v1, Interwave/v1, Intrawave/v3}
// block-GEMM pipeline scheduler/version combinations.
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Col,
                                                      Row,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      BF16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances)
{
    // clang-format off
    // Default (unpadded) specialization.
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
    // Padded specialization (handles sizes not divisible by the tile sizes).
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_bf16_bf16_bf16_bf16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
    // clang-format on
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmPadded = GemmSpecialization::MNKOPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
|
||||
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
|
||||
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
|
||||
// dimensions instead!
|
||||
template <GemmSpecialization GemmSpec,
|
||||
BlockGemmPipelineScheduler PipeSched,
|
||||
BlockGemmPipelineVersion PipeVer>
|
||||
using device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
|
||||
std::
|
||||
tuple<
|
||||
// clang-format off
|
||||
//################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| BlkGemm| BlkGemm|
|
||||
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| PipeSched| PipelineVer|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 32, 16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, 1, 1, S<1, 16, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 64, 32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 64, 32, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, 1, 1, S<1, 32, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 128, 64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 128, 64, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, 1, 1, S<1, 64, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
|
||||
DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Row, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
// Appends every F16 WMMA CShuffleV3 batched GEMM+GEMM instance for the
// gmk_gnk_gno_gmo layout combination (ALayout=Row, B0Layout=Col, B1Layout=Row,
// CLayout=Row) to `instances`.
//
// Six instance sets are registered: the {GemmDefault, GemmPadded} GEMM
// specializations crossed with the {Intrawave/v1, Interwave/v1, Intrawave/v3}
// block-GEMM pipeline scheduler/version combinations.
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Row,
                                                      Row,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances)
{
    // clang-format off
    // Default (unpadded) specialization.
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
    // Padded specialization (handles sizes not divisible by the tile sizes).
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
    // clang-format on
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,92 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmPadded = GemmSpecialization::MNKOPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto PipeVerV1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto PipeVerV3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// gemm0: Acc[g, m, n] = A[g, m, k] * B0[g, k, n]
|
||||
// gemm1: C[g, m, o] = Acc[g, m, n] * B1[g, n, o]
|
||||
// Note that in some cases the "m, o, n" dimensions are referred to as the "gemm1 m, n, k"
|
||||
// dimensions instead!
|
||||
// F16 instance table for the gmk_gnk_gon_gmo layout combination
// (ALayout=Row, B0Layout=Col, B1Layout=Col, CLayout=Row), parameterized over
// the GEMM specialization and the block-GEMM pipeline scheduler/version.
//
// Template arguments per instance, in order: layouts (A, B0, B1, C); data
// types (A, B0, B1, C, Acc, CShuffle); element-wise ops (A, B0, Acc0, B1, C);
// GEMM specialization; block size; tile sizes (Gemm01 MPerBlock, Gemm0
// N/KPerBlock, Gemm1 N/KPerBlock); AK1/BK1/B1K1; MPer/NPer XDL; Gemm0
// MXdl/NXdl and Gemm1 NXdl per wave; A/B0/B1 block-transfer descriptors
// (cluster lengths, arrange order, src access order, src vector dim, src/dst
// scalars per vector, extra-LDS padding); CShuffle M/N Xdl-per-wave per
// shuffle; C block-transfer cluster lengths and scalar-per-vector; pipeline
// scheduler and version.
template <GemmSpecialization GemmSpec,
          BlockGemmPipelineScheduler PipeSched,
          BlockGemmPipelineVersion PipeVer>
using device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances =
    std::tuple<
        // clang-format off
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec,  32,  16,  64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2,  16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1,  16, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec,  32,  16, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2,  16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1,  16, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec,  64,  32, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2,  32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1,  32, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec,  64,  32,  64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2,  32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1,  32, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128,  64, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2,  64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1,  64, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 128,  64,  64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2,  64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1,  64, 1, 2>, 8, PipeSched, PipeVer>,
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>,
        // NOTE(review): this row is byte-identical to the previous one -
        // looks like a copy-paste duplicate; confirm whether the second
        // 256-block entry was meant to use a different tuning.
        DeviceBatchedGemmGemm_Wmma_CShuffleV3< Row, Col, Col, Row, F16, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 64, 64, 8, 8, 8, 16, 16, 1, 8, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 128, 1, 2>, 8, PipeSched, PipeVer>
        // clang-format on
        >;
|
||||
|
||||
// Appends every F16 WMMA CShuffleV3 batched GEMM+GEMM instance for the
// gmk_gnk_gon_gmo layout combination (ALayout=Row, B0Layout=Col, B1Layout=Col,
// CLayout=Row) to `instances`.
//
// Six instance sets are registered: the {GemmDefault, GemmPadded} GEMM
// specializations crossed with the {Intrawave/v1, Interwave/v1, Intrawave/v3}
// block-GEMM pipeline scheduler/version combinations.
void add_device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
                                                      Col,
                                                      Col,
                                                      Row,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      F16,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough,
                                                      PassThrough>>>& instances)
{
    // clang-format off
    // Default (unpadded) specialization.
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmDefault, Intrawave, PipeVerV3>{});
    // Padded specialization (handles sizes not divisible by the tile sizes).
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Interwave, PipeVerV1>{});
    add_device_operation_instances(instances, device_batched_gemm_gemm_wmma_cshuffle_v3_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances<GemmPadded, Intrawave, PipeVerV3>{});
    // clang-format on
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -220,9 +220,10 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
int num_supported_instances = 0;
|
||||
|
||||
// profile device op instances
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
@@ -255,6 +256,7 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
num_supported_instances++;
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float ave_time =
|
||||
@@ -309,6 +311,8 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,
|
||||
}
|
||||
}
|
||||
|
||||
printf("\033[36mFound %d supported instances\033[0m\n", num_supported_instances);
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
|
||||
@@ -41,7 +41,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_OPS profile_gemm_reduce.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_add_relu_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add.cpp)
|
||||
list(APPEND PROFILER_OPS profile_gemm_add_add_fastgelu.cpp)
|
||||
@@ -98,6 +97,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp)
|
||||
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
|
||||
list(APPEND PROFILER_OPS profile_batched_gemm_gemm.cpp)
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
@@ -155,7 +155,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_gemm_streamk_instance)
|
||||
@@ -219,6 +218,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[1
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
|
||||
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp)
|
||||
add_gtest_executable(test_batched_gemm_gemm_fp16_xdl test_batched_gemm_gemm_fp16_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
add_custom_target(test_batched_gemm_gemm)
|
||||
target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
|
||||
target_link_libraries(test_batched_gemm_gemm_fp16_xdl PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_batched_gemm_gemm_bf16_wmma test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_gemm_bf16_wmma PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_batched_gemm_gemm_fp16_wmma test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_gemm_fp16_wmma PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "test_batched_gemm_gemm_util.hpp"

// Typed-test fixture for BF16 batched GEMM + GEMM; the tuple parameter
// carries the four data types followed by the A/B0/B1/C layouts.
template <typename Tuple>
class TestBatchedGemmGemmBF16 : public TestBatchedGemmGemm<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    std::tuple<BF16, BF16, BF16, BF16, Row, Col, Row, Row>,
    std::tuple<BF16, BF16, BF16, BF16, Row, Col, Col, Row>
    >;
// clang-format on

TYPED_TEST_SUITE(TestBatchedGemmGemmBF16, KernelTypes);

// Fixture-default problem sizes, verified against the reference.
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16)
{
    this->verify_ = true;
    this->bench_  = true;
    this->Run();
}

// M not divisible by the tile size.
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadM)
{
    this->lengths_ = {{136, 128, 32, 128, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// N not divisible by the tile size.
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadN)
{
    this->lengths_ = {{128, 136, 32, 128, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// K not divisible by the tile size.
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadK)
{
    this->lengths_ = {{128, 128, 40, 128, 1}, {128, 128, 136, 128, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// O not divisible by the tile size.
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadO)
{
    this->lengths_ = {{128, 128, 32, 136, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// Odd (not even) M extent.
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddM)
{
    this->lengths_ = {{129, 128, 32, 128, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// Odd (not even) N extent.
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddN)
{
    this->lengths_ = {{128, 129, 32, 128, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// Odd (not even) K extents.
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddK)
{
    this->lengths_ = {{128, 128, 33, 128, 1}, {128, 128, 129, 128, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// Odd (not even) O extent. If the kernel's B1Layout is RowMajor, odd O is
// expected to be unsupported.
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddO)
{
    this->lengths_ = {{128, 128, 32, 129, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// Benchmark-only sweep (no verification); disabled by default.
TYPED_TEST(TestBatchedGemmGemmBF16, DISABLED_Bench_BF16)
{
    this->lengths_ = {
        {256, 256, 64, 64, 768},
        {256, 256, 128, 128, 768},
        {512, 512, 64, 64, 768},
        {512, 512, 128, 128, 768},
        {1024, 1024, 64, 64, 768},
        {1024, 1024, 128, 128, 768},
        {2048, 2048, 64, 64, 768},
        {2048, 2048, 128, 128, 768},
        {4096, 4096, 64, 64, 768},
        {4096, 4096, 128, 128, 768},
    };
    this->verify_ = false;
    this->bench_  = true;
    this->Run();
}
|
||||
@@ -0,0 +1,128 @@
|
||||
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#include "gtest/gtest.h"
#include "test_batched_gemm_gemm_util.hpp"

// Typed-test fixture for FP16 batched GEMM + GEMM; the tuple parameter
// carries the four data types followed by the A/B0/B1/C layouts.
template <typename Tuple>
class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
{
};

// clang-format off
using KernelTypes = ::testing::Types<
    std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>,
    std::tuple<F16, F16, F16, F16, Row, Col, Col, Row>
    >;
// clang-format on

TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);

// Fixture-default problem sizes, verified against the reference.
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16)
{
    this->verify_ = true;
    this->bench_  = true;
    this->Run();
}

// M not divisible by the tile size.
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM)
{
    this->lengths_ = {{136, 128, 32, 128, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// N not divisible by the tile size.
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN)
{
    this->lengths_ = {{128, 136, 32, 128, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// K not divisible by the tile size.
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK)
{
    this->lengths_ = {{128, 128, 40, 128, 1}, {128, 128, 136, 128, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// O not divisible by the tile size.
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO)
{
    this->lengths_ = {{128, 128, 32, 136, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// Odd (not even) M extent.
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM)
{
    this->lengths_ = {{129, 128, 32, 128, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}

// Odd (not even) N extent.
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
{
    this->lengths_ = {{128, 129, 32, 128, 1}};
    this->verify_  = true;
    this->bench_   = true;
    this->Run();
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 33, 128, 1},
|
||||
{128, 128, 129, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
// If kernel B1Layout is RowMajor, expect not to support odd O size
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 129, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{256, 256, 64, 64, 768},
|
||||
{256, 256, 128, 128, 768},
|
||||
{512, 512, 64, 64, 768},
|
||||
{512, 512, 128, 128, 768},
|
||||
{1024, 1024, 64, 64, 768},
|
||||
{1024, 1024, 128, 128, 768},
|
||||
{2048, 2048, 64, 64, 768},
|
||||
{2048, 2048, 128, 128, 768},
|
||||
{4096, 4096, 64, 64, 768},
|
||||
{4096, 4096, 128, 128, 768},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = false;
|
||||
this->Run();
|
||||
}
|
||||
@@ -1,8 +1,126 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_batched_gemm_gemm_util.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ALayout = Row;
|
||||
using B0Layout = Col;
|
||||
using B1Layout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using ADataType = F16;
|
||||
using B0DataType = F16;
|
||||
using B1DataType = F16;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = F16;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = PassThrough;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
|
||||
|
||||
using DeviceGemmGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
GemmSpec,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
32, // Gemm1KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
2, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<8, 32, 1>, // B1BlockTransfer
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
false,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
bool IsSupported(int M, int N, int K, int O)
|
||||
{
|
||||
auto gemm = DeviceGemmGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
|
||||
static_cast<B0DataType*>(nullptr),
|
||||
static_cast<B1DataType*>(nullptr),
|
||||
static_cast<CDataType*>(nullptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
0, // BatchCount
|
||||
0, // StrideA
|
||||
0, // StrideB0
|
||||
0, // StrideB1
|
||||
0, // StrideC
|
||||
0, // BatchStrideA
|
||||
0, // BatchStrideB0
|
||||
0, // BatchStrideB1
|
||||
0, // BatchStrideC
|
||||
PassThrough{}, // a_element_op
|
||||
PassThrough{}, // b0_element_op
|
||||
PassThrough{}, // acc0_element_op
|
||||
PassThrough{}, // b1_element_op
|
||||
PassThrough{}); // c_element_op
|
||||
|
||||
return gemm.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <vector>
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "profiler/profile_batched_gemm_gemm_impl.hpp"
|
||||
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
@@ -13,7 +12,8 @@ using ck::tensor_operation::device::GemmSpecialization;
|
||||
template <ck::index_t N>
|
||||
using I = ck::Number<N>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -70,120 +70,3 @@ struct TestBatchedGemmGemm : public ::testing::Test
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ALayout = Row;
|
||||
using B0Layout = Col;
|
||||
using B1Layout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using ADataType = F16;
|
||||
using B0DataType = F16;
|
||||
using B1DataType = F16;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = F16;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = PassThrough;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
|
||||
|
||||
using DeviceGemmGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
GemmSpec,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
32, // Gemm1KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
2, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<8, 32, 1>, // B1BlockTransfer
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
false,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
bool IsSupported(int M, int N, int K, int O)
|
||||
{
|
||||
auto gemm = DeviceGemmGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
|
||||
static_cast<B0DataType*>(nullptr),
|
||||
static_cast<B1DataType*>(nullptr),
|
||||
static_cast<CDataType*>(nullptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
0, // BatchCount
|
||||
0, // StrideA
|
||||
0, // StrideB0
|
||||
0, // StrideB1
|
||||
0, // StrideC
|
||||
0, // BatchStrideA
|
||||
0, // BatchStrideB0
|
||||
0, // BatchStrideB1
|
||||
0, // BatchStrideC
|
||||
PassThrough{}, // a_element_op
|
||||
PassThrough{}, // b0_element_op
|
||||
PassThrough{}, // acc0_element_op
|
||||
PassThrough{}, // b1_element_op
|
||||
PassThrough{}); // c_element_op
|
||||
|
||||
return gemm.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user