mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add a gpu gemm reference kernel (#1528)
* Add a gpu gemm reference kernel * Switch to gpu reference in gemm examples * Remove redundant arguments * Update all related examples * Update more examples * Try less threads per block * Try even less threads per block * Add support for all matrix layouts * Increase block size * Clean up * Remove hardcoded strides * Clean up * Try a column-major case * Revert back to row-major * Run both CPU and GPU veriffication --------- Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -21,6 +21,7 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp"
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
@@ -28,9 +29,9 @@ struct ProblemSize final
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
ck::index_t StrideA = 0;
|
||||
ck::index_t StrideB = 0;
|
||||
ck::index_t StrideC = 0;
|
||||
};
|
||||
|
||||
struct ProblemSizeStreamK final
|
||||
@@ -39,9 +40,9 @@ struct ProblemSizeStreamK final
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
ck::index_t StrideA = 0;
|
||||
ck::index_t StrideB = 0;
|
||||
ck::index_t StrideC = 0;
|
||||
|
||||
ck::index_t NumSKBlocks = -1;
|
||||
};
|
||||
@@ -51,9 +52,9 @@ struct ProblemSizeStreamK_universal final
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
ck::index_t StrideA = 0;
|
||||
ck::index_t StrideB = 0;
|
||||
ck::index_t StrideC = 0;
|
||||
|
||||
ck::index_t Grid_size = -1; // defaults to max occupancy
|
||||
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
|
||||
@@ -65,9 +66,9 @@ struct ProblemSizeSplitK final
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
ck::index_t StrideA = 0;
|
||||
ck::index_t StrideB = 0;
|
||||
ck::index_t StrideC = 0;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
};
|
||||
@@ -125,7 +126,7 @@ bool parse_cmd_args<ProblemSize>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
@@ -175,7 +176,7 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
|
||||
else
|
||||
{
|
||||
std::cerr
|
||||
<< "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
<< "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
@@ -224,7 +225,7 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
@@ -274,7 +275,7 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
|
||||
Reference in New Issue
Block a user