mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Gemm+Reduce Fusion (#128)
* add gridwise gemm v4r1 * rename * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * use sfc in shuffling * remove hardcode * remove hardcode * refactor * fix build * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * adding gemm+reduce * format * clean * adding gemm+reduce * adding profiler for gemm+reduce * adding gemm+reduce profiler * fix build * clean up * gemm+reduce * fix build * update DeviceGemm_Xdl_CShuffle; update enum to enum class * clean up * add test for gemm+reduce * clean up * refactor * fix build * fix build
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
#include "tensor_layout.hpp"
|
||||
#include "device_gemm_xdl_splitk.hpp"
|
||||
|
||||
enum GemmMatrixLayout
|
||||
enum struct GemmMatrixLayout
|
||||
{
|
||||
MK_KN_MN, // 0
|
||||
MK_NK_MN, // 1
|
||||
@@ -59,7 +59,7 @@ static bool check_out(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
|
||||
struct gemmArgs
|
||||
{
|
||||
int layout;
|
||||
GemmMatrixLayout layout;
|
||||
int M;
|
||||
int N;
|
||||
int K;
|
||||
@@ -216,13 +216,13 @@ int main(int argc, char* argv[])
|
||||
std::vector<gemmArgs> test_cases;
|
||||
if(argc == 1)
|
||||
{
|
||||
test_cases = {{0, 3, 3, 3, 3, 3, 3, 1}};
|
||||
test_cases = {{GemmMatrixLayout::MK_KN_MN, 3, 3, 3, 3, 3, 3, 1}};
|
||||
// JD: Populate with more and meaningful
|
||||
return 0;
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
const int layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[1]));
|
||||
|
||||
const int M = std::stoi(argv[2]);
|
||||
const int N = std::stoi(argv[3]);
|
||||
|
||||
Reference in New Issue
Block a user