Fused GEMM+GEMM (#351)

* initial stub for gemm_gemm_xdl_cshuffle

* set up example code

* compiles

* prevent integer overflow

* harmonize interface between ref_gemm and ref_batched_gemm

* batched_gemm_gemm

* fix example

* host tensor gen: diagonal pattern in lowest two-dimensions only

* make c descriptors containing only integral constants

* clean up

* add BlockwiseGemmXdlops_v2 while exploring an unified approach

* implement proper interface

* tidy up example

* fix compilation warnings

* coarsely controlled 2nd gemm padding

* remove rocm-cmake's hard requirement for certain revision

* clang-format

* resolve merge conflict

* fix compilation error on gfx10

* adds acc0 elementwise op to interface

* add gemm_gemm instances and tests

* avoid LDS data hazard

* fix build

Co-authored-by: Chao Liu <chao.liu2@amd.com>

[ROCm/composable_kernel commit: c20a75b07d]
This commit is contained in:
Anthony Chang
2022-08-13 22:18:58 +08:00
committed by GitHub
parent 4dab000987
commit bd30eaf33b
32 changed files with 2822 additions and 689 deletions

View File

@@ -40,6 +40,7 @@ add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(reduce)

View File

@@ -0,0 +1,5 @@
add_custom_target(test_batched_gemm_gemm)
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp)
target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)

View File

@@ -0,0 +1,39 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_batched_gemm_gemm_util.hpp"
template <typename Tuple>
class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>
>;
// clang-format on
TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16) { this->Run(); }
TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
{
this->lengths_ = std::vector<std::vector<int>>{
{256, 256, 64, 64, 768},
{256, 256, 128, 128, 768},
{512, 512, 64, 64, 768},
{512, 512, 128, 128, 768},
{1024, 1024, 64, 64, 768},
{1024, 1024, 128, 128, 768},
{2048, 2048, 64, 64, 768},
{2048, 2048, 128, 128, 768},
{4096, 4096, 64, 64, 768},
{4096, 4096, 128, 128, 768},
};
this->bench_ = true;
this->verify_ = false;
this->Run();
}

View File

@@ -0,0 +1,68 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
template <ck::index_t N>
using I = ck::Number<N>;
using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
struct TestBatchedGemmGemm : public ::testing::Test
{
using ADataType = std::tuple_element_t<0, Tuple>;
using B0DataType = std::tuple_element_t<1, Tuple>;
using B1DataType = std::tuple_element_t<2, Tuple>;
using CDataType = std::tuple_element_t<3, Tuple>;
using ALayout = std::tuple_element_t<4, Tuple>;
using B0Layout = std::tuple_element_t<5, Tuple>;
using B1Layout = std::tuple_element_t<6, Tuple>;
using CLayout = std::tuple_element_t<7, Tuple>;
std::vector<std::vector<int>> lengths_ = {
{256, 256, 64, 64, 4},
{256, 256, 128, 128, 4},
{512, 512, 64, 64, 2},
{512, 512, 128, 128, 2},
{1024, 1024, 64, 64, 1},
{1024, 1024, 128, 128, 1},
};
bool bench_ = false;
bool verify_ = true;
void RunSingle(int M, int N, int K, int O, int BatchCount)
{
bool pass = ck::profiler::profile_batched_gemm_gemm_impl<ADataType,
B0DataType,
B1DataType,
CDataType,
ALayout,
B0Layout,
B1Layout,
CLayout>(
verify_, 1, false, bench_, M, N, K, O, BatchCount);
EXPECT_TRUE(pass);
}
void Run()
{
for(auto lengths : this->lengths_)
{
int M = lengths[0];
int N = lengths[1];
int K = lengths[2];
int O = lengths[3];
int BatchCount = lengths[4];
this->RunSingle(M, N, K, O, BatchCount);
}
}
};