mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Gemm+Reduce Fusion (#128)
* add gridwise gemm v4r1
* rename
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* use sfc in shuffling
* remove hardcode
* remove hardcode
* refactor
* fix build
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* adding gemm+reduce
* format
* clean
* adding gemm+reduce
* adding profiler for gemm+reduce
* adding gemm+reduce profiler
* fix build
* clean up
* gemm+reduce
* fix build
* update DeviceGemm_Xdl_CShuffle; update enum to enum class
* clean up
* add test for gemm+reduce
* clean up
* refactor
* fix build
* fix build
[ROCm/composable_kernel commit: f95267f166]
This commit is contained in:
@@ -16,6 +16,7 @@ include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu
|
||||
${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu
|
||||
${PROJECT_SOURCE_DIR}/test/include
|
||||
${PROJECT_SOURCE_DIR}/profiler/include
|
||||
${PROJECT_SOURCE_DIR}/external/include/half
|
||||
)
|
||||
|
||||
@@ -35,9 +36,10 @@ add_subdirectory(space_filling_curve)
|
||||
add_subdirectory(conv_util)
|
||||
add_subdirectory(reference_conv_fwd)
|
||||
add_subdirectory(gemm)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(gemm_split_k)
|
||||
add_subdirectory(gemm_reduce)
|
||||
add_subdirectory(batched_gemm)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(convnd_fwd)
|
||||
add_subdirectory(conv2d_bwd_data)
|
||||
add_subdirectory(batched_gemm)
|
||||
add_subdirectory(reduce)
|
||||
|
||||
9
test/gemm_reduce/CMakeLists.txt
Normal file
9
test/gemm_reduce/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
include_directories(BEFORE
|
||||
${PROJECT_SOURCE_DIR}/profiler/include
|
||||
${PROJECT_SOURCE_DIR}/test/include
|
||||
${PROJECT_SOURCE_DIR}/external/include/half
|
||||
)
|
||||
|
||||
add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp)
|
||||
target_link_libraries(test_gemm_reduce_fp16 PRIVATE host_tensor)
|
||||
target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance)
|
||||
52
test/gemm_reduce/gemm_reduce_fp16.cpp
Normal file
52
test/gemm_reduce/gemm_reduce_fp16.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <half.hpp>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "profile_gemm_reduce_impl.hpp"
|
||||
|
||||
int main()
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
int M = 512;
|
||||
int N = 256;
|
||||
int K = 128;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::
|
||||
profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Row, Row, Row>(
|
||||
true, 1, false, 1, M, N, K, K, N, N);
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::
|
||||
profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Row, Col, Row>(
|
||||
true, 1, false, 1, M, N, K, K, K, N);
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::
|
||||
profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Col, Row, Row>(
|
||||
true, 1, false, 1, M, N, K, M, N, N);
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::
|
||||
profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Col, Col, Row>(
|
||||
true, 1, false, 1, M, N, K, M, K, N);
|
||||
|
||||
if(pass)
|
||||
{
|
||||
std::cout << "test GEMM+Reduce fp16: Pass" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "test GEMM+Reduce fp16: Fail" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "tensor_layout.hpp"
|
||||
#include "device_gemm_xdl_splitk.hpp"
|
||||
|
||||
enum GemmMatrixLayout
|
||||
enum struct GemmMatrixLayout
|
||||
{
|
||||
MK_KN_MN, // 0
|
||||
MK_NK_MN, // 1
|
||||
@@ -59,7 +59,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
|
||||
struct gemmArgs
|
||||
{
|
||||
int layout;
|
||||
GemmMatrixLayout layout;
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
@@ -216,13 +216,13 @@ int main(int argc, char* argv[])
|
||||
std::vector<gemmArgs> test_cases;
|
||||
if(argc == 1)
|
||||
{
|
||||
test_cases = {{0, 3, 3, 3, 3, 3, 3, 1}};
|
||||
test_cases = {{GemmMatrixLayout::MK_KN_MN, 3, 3, 3, 3, 3, 3, 1}};
|
||||
// JD: Populate with more and meaningful
|
||||
return 0;
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
|
||||
const int M = std::stoi(argv[2]);
|
||||
const int N = std::stoi(argv[3]);
|
||||
|
||||
Reference in New Issue
Block a user