Implement batched gemm add relu gemm add for rdna4 (#3391)

* wip: test suite for batched gemm multiple d gemm multiple d, working on gridwise implementation

* wip: many fixes in implementation of batched gemm gemm multiple d

* wip: batched gemm gemm multiple d gridwise op compiling, not working yet

* fix: incorrect d0 grid indexing in batched gemm gemm multiple d

* feat: add instances for batched gemm add relu gemm add

* chore: configure instance with low vector transfer size for odd sizes

* chore: add some more validation to device batched gemm gemm multiple d, and removed template parameter that didn't really make sense

* fix: update device_batched_gemm_gemm_wmma to work with new gridwise changes

* fix: disable odd size tests on XDL archs

* chore: removed temporary logging

* chore: update some references to C tensor to E tensor

* Tentative fix for example template params

* Tentative fix for non-multi-D batched gemm gemm device impl.

* Tentative fix for xdl example template params

* Tentative fix for profiler build on gfx90a

* chore: improve device batched gemm gemm multi D comment to include all ops and dimensions

* chore: explicitly call ck::make_tuple to prevent issues when std::make_tuple would apply

* fix: make the gemm1 data types match what happens in the device op

* feat: add d0s/d1s datatypes and layouts to the device op type string

* chore: change element-wise op so addition happens in fp32

* chore: add static asserts for gemm0/gemm1 calculated wave sizes

* chore: also updated other element-wise ops to use fp32 calculations

* chore: log number of supported instances

* chore: update instance comment

* chore: disable kernel timing in example by default

* fix: gemm1 wave size calculation

* fix: make sure batched gemm multiple d gemm multiple d profiler performs correct type conversions

* chore: remove increased tolerance in batched gemm gemm multiple d example

* chore: add comment explaining that verification fails for certain input values

* chore: clarify instance comment

---------

Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
This commit is contained in:
Erwin Terpstra
2026-01-20 22:06:59 +01:00
committed by GitHub
parent 91b4102a59
commit d5ae81b292
22 changed files with 2956 additions and 499 deletions

View File

@@ -275,6 +275,7 @@ add_subdirectory(batched_contraction)
# Batched-GEMM test/example subdirectories; batched_gemm_multiple_d_gemm_multiple_d
# is the new suite added by this change.
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_multiple_d_gemm_multiple_d)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(batched_gemm_b_scale)

View File

@@ -0,0 +1,12 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
# the instance library if there are no instances present for the current arch.
if (CK_USE_XDL OR CK_USE_WMMA)
add_gtest_executable(test_batched_gemm_add_relu_gemm_add test_batched_gemm_add_relu_gemm_add.cpp)
# NOTE(review): `result` is presumably set by add_gtest_executable (0 == target
# created) — confirm against that macro's definition before relying on it here.
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_add_relu_gemm_add PRIVATE utility device_batched_gemm_add_relu_gemm_add_instance)
endif()
endif()

View File

@@ -0,0 +1,27 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_batched_gemm_multiple_d_gemm_multiple_d.hpp"
// Concrete typed-test fixture; all configuration (problem sizes, verify/bench
// flags, Run/RunSingle drivers) is inherited from the base fixture declared in
// test_batched_gemm_multiple_d_gemm_multiple_d.hpp.
template <typename Tuple>
class TestBatchedGemmMultipleDGemmMultipleD
: public BaseTestBatchedGemmMultipleDGemmMultipleD<Tuple>
{
};
// Element-wise ops per GEMM stage: A0/B0/B1 pass through unchanged, the first
// GEMM's output is fused with its D0 operand via AddRelu, and the second
// GEMM's output with its D1 operand via Add.
using A0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using B1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
// Tuple slots (must match the base fixture's std::tuple_element_t indices):
// 0-5 data types (A0, B0, D0s, B1, D1s, E), 6-11 layouts in the same order,
// 12-16 the five element-wise ops. The two cases differ only in B1's layout
// (Row vs Col).
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, A0ElementOp, B0ElementOp, CDE0ElementOp, B1ElementOp, CDE1ElementOp>,
std::tuple<F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, A0ElementOp, B0ElementOp, CDE0ElementOp, B1ElementOp, CDE1ElementOp>
>;
// clang-format on
TYPED_TEST_SUITE(TestBatchedGemmMultipleDGemmMultipleD, KernelTypes);
#include "test_batched_gemm_multiple_d_gemm_multiple_d_ut_cases.inc"

View File

@@ -0,0 +1,121 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <vector>
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "profiler/profile_batched_gemm_multiple_d_gemm_multiple_d_impl.hpp"
// Shorthand aliases shared by the per-instantiation test .cpp files that
// include this header.
using ck::tensor_operation::device::GemmSpecialization;
template <ck::index_t N>
using I = ck::Number<N>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// Shared typed-test fixture for batched GEMM + multiple-D + GEMM + multiple-D.
// Tuple slots: 0-5 data types (A, B0, D0s, B1, D1s, E), 6-11 the matching
// layouts, 12-16 the element-wise ops applied per stage.
template <typename Tuple>
struct BaseTestBatchedGemmMultipleDGemmMultipleD : public ::testing::Test
{
    using ADataType     = std::tuple_element_t<0, Tuple>;
    using B0DataType    = std::tuple_element_t<1, Tuple>;
    using D0sDataType   = std::tuple_element_t<2, Tuple>;
    using B1DataType    = std::tuple_element_t<3, Tuple>;
    using D1sDataType   = std::tuple_element_t<4, Tuple>;
    using EDataType     = std::tuple_element_t<5, Tuple>;
    using ALayout       = std::tuple_element_t<6, Tuple>;
    using B0Layout      = std::tuple_element_t<7, Tuple>;
    using D0sLayout     = std::tuple_element_t<8, Tuple>;
    using B1Layout      = std::tuple_element_t<9, Tuple>;
    using D1sLayout     = std::tuple_element_t<10, Tuple>;
    using ELayout       = std::tuple_element_t<11, Tuple>;
    using A0ElementOp   = std::tuple_element_t<12, Tuple>;
    using B0ElementOp   = std::tuple_element_t<13, Tuple>;
    using CDE0ElementOp = std::tuple_element_t<14, Tuple>;
    using B1ElementOp   = std::tuple_element_t<15, Tuple>;
    using CDE1ElementOp = std::tuple_element_t<16, Tuple>;

    // Default problem sizes; each entry is {M, N, K, O, BatchCount}.
    // Individual test cases overwrite this list before calling Run().
    std::vector<std::vector<int>> lengths_ = {
        {256, 256, 64, 64, 4},
        {256, 256, 128, 128, 4},
        {512, 512, 64, 64, 2},
        {512, 512, 128, 128, 2},
        {1024, 1024, 64, 64, 1},
        {1024, 1024, 128, 128, 1},
    };

    bool bench_  = false; // also time the kernels (off by default)
    bool verify_ = true;  // compare device output against the host reference

    // Profiles all registered instances for one problem size and fails the
    // test (EXPECT_TRUE) if the profiler reports failure.
    void RunSingle(int M, int N, int K, int O, int BatchCount)
    {
        // WMMA instances are setup to support all the test cases
        // XDL instances are not, so only treat "no supported instance" as a
        // failure on WMMA-capable (gfx11/gfx12) targets.
        bool fail_if_no_supported_instances = ck::is_gfx11_supported() || ck::is_gfx12_supported();
        bool pass =
            ck::profiler::profile_batched_gemm_multiple_d_gemm_multiple_d_impl<ALayout,
                                                                               B0Layout,
                                                                               D0sLayout,
                                                                               B1Layout,
                                                                               D1sLayout,
                                                                               ELayout,
                                                                               ADataType,
                                                                               B0DataType,
                                                                               D0sDataType,
                                                                               B1DataType,
                                                                               D1sDataType,
                                                                               EDataType,
                                                                               A0ElementOp,
                                                                               B0ElementOp,
                                                                               CDE0ElementOp,
                                                                               B1ElementOp,
                                                                               CDE1ElementOp>(
                verify_,
                1,     // init method
                false, // do not log tensor values
                bench_,
                M,
                N,
                K,
                O,
                BatchCount,
                // -1 for every stride/batch-stride below — presumably lets the
                // profiler derive packed defaults; confirm against
                // profile_batched_gemm_multiple_d_gemm_multiple_d_impl.
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                fail_if_no_supported_instances);
        EXPECT_TRUE(pass);
    }

    // Runs RunSingle for every configured problem size.
    void Run()
    {
        // Iterate by const reference: the original by-value loop copied each
        // inner std::vector<int> per iteration (performance-for-range-copy).
        for(const auto& lengths : this->lengths_)
        {
            const int M          = lengths[0];
            const int N          = lengths[1];
            const int K          = lengths[2];
            const int O          = lengths[3];
            const int BatchCount = lengths[4];
            this->RunSingle(M, N, K, O, BatchCount);
        }
    }
};

View File

@@ -0,0 +1,88 @@
// Default size sweep from the fixture's lengths_ list.
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test) { this->Run(); }
// Padding cases: one dimension per test is not a multiple of the tile size,
// exercising the padded kernel specializations. Entries are {M, N, K, O, Batch}.
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadM)
{
this->lengths_ = std::vector<std::vector<int>>{
{136, 128, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 136, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 40, 128, 1},
{128, 128, 136, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 136, 1},
};
this->Run();
}
// Odd-size cases: dimensions that are odd (not even vector-transfer multiples).
// Only the WMMA (gfx11/gfx12) instances are configured with vector sizes small
// enough to handle these, so the tests are skipped on other targets.
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddM)
{
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
}
this->lengths_ = std::vector<std::vector<int>>{
{129, 128, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddN)
{
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
}
this->lengths_ = std::vector<std::vector<int>>{
{128, 129, 32, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddK)
{
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
}
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 33, 128, 1},
{128, 128, 129, 128, 1},
};
this->Run();
}
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddO)
{
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
}
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 129, 1},
};
this->Run();
}