mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Implement batched gemm add relu gemm add for rdna4 (#3391)
* wip: test suite for batched gemm multiple d gemm multiple d, working on gridwise implementation * wip: many fixes in implementation of batched gemm gemm multiple d * wip: batched gemm gemm multiple d gridwise op compiling, not working yet * fix: incorrect d0 grid indexing in batched gemm gemm multiple d * feat: add instances for batched gemm add relu gemm add * chore: configure instance with low vector transfer size for odd sizes * chore: add some more validation to device batched gemm gemm multiple d, and removed template parameter that didn't really make sense * fix: update device_batched_gemm_gemm_wmma to work with new gridwise changes * fix: disable odd size tests on XDL archs * chore: removed temporary logging * chore: update some references to C tensor to E tensor * Tentative fix for example template params * Tentative fix for non-multi-D batched gemm gemm device impl. * Tentative fix for xdl example template params * Tentative fix for profiler build on gfx90a * chore: improve device batched gemm gemm multi D comment to include all ops and dimensions * chore: explicitly call ck::make_tuple to prevent issues when std::make_tuple would apply * fix: make the gemm1 data types match what happens in the device op * feat: add d0s/d1s datatypes and layouts to the device op type string * chore: change element-wise op so addition happens in fp32 * chore: add static asserts for gemm0/gemm1 calculated wave sizes * chore: also updated other element-wise ops to use fp32 calculations * chore: log number of supported instances * chore: update instance comment * chore: disable kernel timing in example by default * fix: gemm1 wave size calculation * fix: make sure batched gemm multiple d gemm multiple d profiler performs correct type conversions * chore: remove increased tolerance in batched gemm gemm multiple d example * chore: add comment explaining that verification fails for certain input values * chore: clarify instance comment --------- Co-authored-by: kiefer 
<kiefer.van.teutem@streamhpc.com>
This commit is contained in:
@@ -275,6 +275,7 @@ add_subdirectory(batched_contraction)
|
||||
add_subdirectory(batched_gemm)
|
||||
add_subdirectory(batched_gemm_reduce)
|
||||
add_subdirectory(batched_gemm_gemm)
|
||||
add_subdirectory(batched_gemm_multiple_d_gemm_multiple_d)
|
||||
add_subdirectory(batched_gemm_softmax_gemm)
|
||||
add_subdirectory(batched_gemm_softmax_gemm_permute)
|
||||
add_subdirectory(batched_gemm_b_scale)
|
||||
|
||||
12
test/batched_gemm_multiple_d_gemm_multiple_d/CMakeLists.txt
Normal file
12
test/batched_gemm_multiple_d_gemm_multiple_d/CMakeLists.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
# the instance library if there's no instances present for the current arch.
if (CK_USE_XDL OR CK_USE_WMMA)
    add_gtest_executable(test_batched_gemm_add_relu_gemm_add test_batched_gemm_add_relu_gemm_add.cpp)
    # NOTE(review): 'result' is not set in this file — presumably add_gtest_executable()
    # (defined in a parent CMakeLists) sets it, with 0 meaning the target was created,
    # matching the pattern used by sibling test directories. Confirm against that macro.
    if(result EQUAL 0)
        target_link_libraries(test_batched_gemm_add_relu_gemm_add PRIVATE utility device_batched_gemm_add_relu_gemm_add_instance)
    endif()
endif()
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_batched_gemm_multiple_d_gemm_multiple_d.hpp"
|
||||
|
||||
// Typed-test fixture: all behavior comes from the shared base defined in
// test_batched_gemm_multiple_d_gemm_multiple_d.hpp; this alias-like subclass
// only gives the test suite its registered name.
template <typename Tuple>
class TestBatchedGemmMultipleDGemmMultipleD
    : public BaseTestBatchedGemmMultipleDGemmMultipleD<Tuple>
{
};

// Element-wise ops for the two chained GEMMs: gemm0's epilogue adds D0 then
// applies ReLU (AddRelu); gemm1's epilogue adds D1 (Add). A/B inputs pass through.
using A0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;

using B1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;

// Each tuple packs, in order: data types (A, B0, D0s, B1, D1s, E), layouts
// (A, B0, D0s, B1, D1s, E), then the five element-wise ops above. The two
// cases differ only in the B1 layout (Row vs Col).
// clang-format off
using KernelTypes = ::testing::Types<
    std::tuple<F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, A0ElementOp, B0ElementOp, CDE0ElementOp, B1ElementOp, CDE1ElementOp>,
    std::tuple<F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, A0ElementOp, B0ElementOp, CDE0ElementOp, B1ElementOp, CDE1ElementOp>
>;
// clang-format on

TYPED_TEST_SUITE(TestBatchedGemmMultipleDGemmMultipleD, KernelTypes);
// The actual TYPED_TEST bodies are shared with other instantiations via this include.
#include "test_batched_gemm_multiple_d_gemm_multiple_d_ut_cases.inc"
|
||||
@@ -0,0 +1,121 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "profiler/profile_batched_gemm_multiple_d_gemm_multiple_d_impl.hpp"
|
||||
|
||||
using ck::tensor_operation::device::GemmSpecialization;

// Shorthand for ck's compile-time integral constants.
template <ck::index_t N>
using I = ck::Number<N>;

// Element data-type abbreviations used by the KernelTypes tuples.
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;

// Matrix layout abbreviations.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
// Shared gtest fixture for batched (GEMM + multiple-D) -> (GEMM + multiple-D):
// E = CDE1(gemm1(CDE0(gemm0(A, B0), D0s), B1), D1s), run per batch.
// The Tuple packs the full problem configuration; see the index comments below.
template <typename Tuple>
struct BaseTestBatchedGemmMultipleDGemmMultipleD : public ::testing::Test
{
    // <0..5>: element data types.
    using ADataType = std::tuple_element_t<0, Tuple>;
    using B0DataType = std::tuple_element_t<1, Tuple>;
    using D0sDataType = std::tuple_element_t<2, Tuple>;   // ck::Tuple of gemm0 D operands
    using B1DataType = std::tuple_element_t<3, Tuple>;
    using D1sDataType = std::tuple_element_t<4, Tuple>;   // ck::Tuple of gemm1 D operands
    using EDataType = std::tuple_element_t<5, Tuple>;
    // <6..11>: matrix layouts, same order as the data types.
    using ALayout = std::tuple_element_t<6, Tuple>;
    using B0Layout = std::tuple_element_t<7, Tuple>;
    using D0sLayout = std::tuple_element_t<8, Tuple>;
    using B1Layout = std::tuple_element_t<9, Tuple>;
    using D1sLayout = std::tuple_element_t<10, Tuple>;
    using ELayout = std::tuple_element_t<11, Tuple>;
    // <12..16>: element-wise operations around the two GEMMs.
    using A0ElementOp = std::tuple_element_t<12, Tuple>;
    using B0ElementOp = std::tuple_element_t<13, Tuple>;
    using CDE0ElementOp = std::tuple_element_t<14, Tuple>;
    using B1ElementOp = std::tuple_element_t<15, Tuple>;
    using CDE1ElementOp = std::tuple_element_t<16, Tuple>;

    // Default problem sizes; each entry is {M, N, K, O, BatchCount}
    // (gemm0 is MxNxK, gemm1 is MxOxN). Tests may overwrite this list.
    std::vector<std::vector<int>> lengths_ = {
        {256, 256, 64, 64, 4},
        {256, 256, 128, 128, 4},
        {512, 512, 64, 64, 2},
        {512, 512, 128, 128, 2},
        {1024, 1024, 64, 64, 1},
        {1024, 1024, 128, 128, 1},
    };
    bool bench_ = false;  // benchmark (time) the kernels in addition to running them
    bool verify_ = true;  // compare device results against the host reference

    // Profiles one {M, N, K, O, BatchCount} problem and EXPECTs success.
    void RunSingle(int M, int N, int K, int O, int BatchCount)
    {
        // WMMA instances are setup to support all the test cases
        // XDL instances are not.
        bool fail_if_no_supported_instances = ck::is_gfx11_supported() || ck::is_gfx12_supported();

        // NOTE(review): the positional args after bench_ are M, N, K, O,
        // BatchCount followed by twelve -1 sentinels — presumably the batch
        // strides / leading dimensions, with -1 meaning "derive defaults";
        // confirm against profile_batched_gemm_multiple_d_gemm_multiple_d_impl.
        // The leading (verify_, 1, false, bench_) look like (verify, init
        // method, log, time) — also to be confirmed there.
        bool pass =
            ck::profiler::profile_batched_gemm_multiple_d_gemm_multiple_d_impl<ALayout,
                                                                               B0Layout,
                                                                               D0sLayout,
                                                                               B1Layout,
                                                                               D1sLayout,
                                                                               ELayout,
                                                                               ADataType,
                                                                               B0DataType,
                                                                               D0sDataType,
                                                                               B1DataType,
                                                                               D1sDataType,
                                                                               EDataType,
                                                                               A0ElementOp,
                                                                               B0ElementOp,
                                                                               CDE0ElementOp,
                                                                               B1ElementOp,
                                                                               CDE1ElementOp>(
                verify_,
                1,
                false,
                bench_,
                M,
                N,
                K,
                O,
                BatchCount,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                fail_if_no_supported_instances);

        EXPECT_TRUE(pass);
    }

    // Runs RunSingle for every entry in lengths_.
    void Run()
    {
        for(auto lengths : this->lengths_)
        {
            int M = lengths[0];
            int N = lengths[1];
            int K = lengths[2];
            int O = lengths[3];
            int BatchCount = lengths[4];

            this->RunSingle(M, N, K, O, BatchCount);
        }
    }
};
|
||||
@@ -0,0 +1,88 @@
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test) { this->Run(); }
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{136, 128, 32, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadN)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 136, 32, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 40, 128, 1},
|
||||
{128, 128, 136, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 136, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddM)
|
||||
{
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
|
||||
}
|
||||
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{129, 128, 32, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddN)
|
||||
{
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
|
||||
}
|
||||
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 129, 32, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddK)
|
||||
{
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
|
||||
}
|
||||
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 33, 128, 1},
|
||||
{128, 128, 129, 128, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddO)
|
||||
{
|
||||
if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
|
||||
{
|
||||
GTEST_SKIP() << "Odd-sizes only supported on WMMA instances.";
|
||||
}
|
||||
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 129, 1},
|
||||
};
|
||||
this->Run();
|
||||
}
|
||||
Reference in New Issue
Block a user