mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Fused GEMM+GEMM (#351)
* initial stub for gemm_gemm_xdl_cshuffle
* set up example code
* compiles
* prevent integer overflow
* harmonize interface between ref_gemm and ref_batched_gemm
* batched_gemm_gemm
* fix example
* host tensor gen: diagonal pattern in lowest two dimensions only
* make c descriptors containing only integral constants
* clean up
* add BlockwiseGemmXdlops_v2 while exploring a unified approach
* implement proper interface
* tidy up example
* fix compilation warnings
* coarsely controlled 2nd gemm padding
* remove rocm-cmake's hard requirement for a certain revision
* clang-format
* resolve merge conflict
* fix compilation error on gfx10
* adds acc0 elementwise op to interface
* add gemm_gemm instances and tests
* avoid LDS data hazard
* fix build
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: c20a75b07d]
This commit is contained in:
@@ -40,6 +40,7 @@ add_subdirectory(gemm_split_k)
|
||||
add_subdirectory(gemm_reduce)
|
||||
add_subdirectory(batched_gemm)
|
||||
add_subdirectory(batched_gemm_reduce)
|
||||
add_subdirectory(batched_gemm_gemm)
|
||||
add_subdirectory(batched_gemm_softmax_gemm)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(reduce)
|
||||
|
||||
5
test/batched_gemm_gemm/CMakeLists.txt
Normal file
5
test/batched_gemm_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
# Umbrella target: building `test_batched_gemm_gemm` builds every test
# executable registered as a dependency below.
add_custom_target(test_batched_gemm_gemm)

# FP16 test executable and the libraries it links against.
add_gtest_executable(
    test_batched_gemm_gemm_fp16
    test_batched_gemm_gemm_fp16.cpp)
target_link_libraries(
    test_batched_gemm_gemm_fp16
    PRIVATE
        utility
        device_batched_gemm_gemm_instance)

# Hook the FP16 executable into the umbrella target.
add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
|
||||
39
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
Normal file
39
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp
Normal file
@@ -0,0 +1,39 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_batched_gemm_gemm_util.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);
|
||||
|
||||
// Runs the fixture with its default settings (default problem-size list,
// verify_ == true, bench_ == false).
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16)
{
    this->Run();
}
|
||||
|
||||
// Benchmark sweep. The DISABLED_ prefix keeps GoogleTest from running it unless
// disabled tests are explicitly requested (--gtest_also_run_disabled_tests).
TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
{
    constexpr int kBatchCount = 768;

    // Rows are {M, N, K, O, BatchCount} with M == N and K == O. The outer loop
    // walks the M/N size and the inner loop the K/O size, so rows come out as
    // 256/64, 256/128, 512/64, 512/128, ... 4096/64, 4096/128.
    std::vector<std::vector<int>> sweep;
    for(const int mn : {256, 512, 1024, 2048, 4096})
    {
        for(const int ko : {64, 128})
        {
            sweep.push_back({mn, mn, ko, ko, kBatchCount});
        }
    }

    this->lengths_ = sweep;
    this->bench_   = true;  // forwarded by RunSingle() to the profiler call
    this->verify_  = false; // forwarded by RunSingle() to the profiler call
    this->Run();
}
|
||||
68
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp
Normal file
68
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp
Normal file
@@ -0,0 +1,68 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <vector>
|
||||
#include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
|
||||
|
||||
template <ck::index_t N>
|
||||
using I = ck::Number<N>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <typename Tuple>
|
||||
struct TestBatchedGemmGemm : public ::testing::Test
|
||||
{
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using B0DataType = std::tuple_element_t<1, Tuple>;
|
||||
using B1DataType = std::tuple_element_t<2, Tuple>;
|
||||
using CDataType = std::tuple_element_t<3, Tuple>;
|
||||
using ALayout = std::tuple_element_t<4, Tuple>;
|
||||
using B0Layout = std::tuple_element_t<5, Tuple>;
|
||||
using B1Layout = std::tuple_element_t<6, Tuple>;
|
||||
using CLayout = std::tuple_element_t<7, Tuple>;
|
||||
|
||||
std::vector<std::vector<int>> lengths_ = {
|
||||
{256, 256, 64, 64, 4},
|
||||
{256, 256, 128, 128, 4},
|
||||
{512, 512, 64, 64, 2},
|
||||
{512, 512, 128, 128, 2},
|
||||
{1024, 1024, 64, 64, 1},
|
||||
{1024, 1024, 128, 128, 1},
|
||||
};
|
||||
bool bench_ = false;
|
||||
bool verify_ = true;
|
||||
|
||||
void RunSingle(int M, int N, int K, int O, int BatchCount)
|
||||
{
|
||||
bool pass = ck::profiler::profile_batched_gemm_gemm_impl<ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CLayout>(
|
||||
verify_, 1, false, bench_, M, N, K, O, BatchCount);
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
|
||||
void Run()
|
||||
{
|
||||
for(auto lengths : this->lengths_)
|
||||
{
|
||||
int M = lengths[0];
|
||||
int N = lengths[1];
|
||||
int K = lengths[2];
|
||||
int O = lengths[3];
|
||||
int BatchCount = lengths[4];
|
||||
|
||||
this->RunSingle(M, N, K, O, BatchCount);
|
||||
}
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user