mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
Gemm+Reduce Fusion (#128)
* add gridwise gemm v4r1 * rename * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * use sfc in shuffling * remove hardcode * remove hardcode * refactor * fix build * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * format * clean * adding gemm+reduce * adding profiler for gemm+reduce * adding gemm+reduce profiler * fix build * clean up * gemm+reduce * fix build * update DeviceGemm_Xdl_CShuffle; update enum to enum class * clean up * add test for gemm+reduce * clean up * refactor * fix build * fix build
This commit is contained in:
52
test/gemm_reduce/gemm_reduce_fp16.cpp
Normal file
52
test/gemm_reduce/gemm_reduce_fp16.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
#include <algorithm>
|
||||
#include <cstdlib>
|
||||
#include <half.hpp>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "profile_gemm_reduce_impl.hpp"
|
||||
|
||||
int main()
|
||||
{
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
int M = 512;
|
||||
int N = 256;
|
||||
int K = 128;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::
|
||||
profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Row, Row, Row>(
|
||||
true, 1, false, 1, M, N, K, K, N, N);
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::
|
||||
profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Row, Col, Row>(
|
||||
true, 1, false, 1, M, N, K, K, K, N);
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::
|
||||
profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Col, Row, Row>(
|
||||
true, 1, false, 1, M, N, K, M, N, N);
|
||||
|
||||
pass = pass &&
|
||||
ck::profiler::
|
||||
profile_gemm_reduce_impl<ck::half_t, ck::half_t, ck::half_t, float, Col, Col, Row>(
|
||||
true, 1, false, 1, M, N, K, M, K, N);
|
||||
|
||||
if(pass)
|
||||
{
|
||||
std::cout << "test GEMM+Reduce fp16: Pass" << std::endl;
|
||||
return 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "test GEMM+Reduce fp16: Fail" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user