mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 13:29:20 +00:00
[CK TILE] Refactor GemmKernel to be reused by other GEMM related operators (#1730)
* Gemm Kernel Refactor part1
* Gemm Kernel Refactor common gemm pipeline part2
* [CK TILE] Refactor batched gemm to reuse GemmKernel
* [CK TILE] Refactor GemmKernel - review changes part1
* [CK TILE] Refactor GemmKernel - references fix
* [CK TILE] Refactor GemmKernel - naming changes, add problem
* [CK_TILE] Refactor GemmKernel - update tests
* [CK_TILE] Refactor GemmKernel - review changes
* [CK_TILE] Refactor GemmKernel - update test
* [CK_TILE] Refactor GemmKernel - constness fixes
* [CK_TILE] Refactor GemmKernel - update tests
[ROCm/composable_kernel commit: 453ca37347]
This commit is contained in:
@@ -24,12 +24,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
|
||||
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s)
|
||||
void invoke_batched_gemm(const ck_tile::BatchedGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
@@ -94,9 +91,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
using Kernel =
|
||||
ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.batch_count);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
@@ -185,21 +182,22 @@ class TestCkTileBatchedGemm : public ::testing::Test
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
batched_gemm_kargs kargs{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
BatchStrideA,
|
||||
BatchStrideB,
|
||||
BatchStrideC,
|
||||
BatchCount};
|
||||
ck_tile::BatchedGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = StrideA;
|
||||
args.stride_B = StrideB;
|
||||
args.stride_C = StrideC;
|
||||
args.batch_stride_A = BatchStrideA;
|
||||
args.batch_stride_B = BatchStrideB;
|
||||
args.batch_stride_C = BatchStrideC;
|
||||
args.batch_count = BatchCount;
|
||||
|
||||
invoke_batched_gemm<ALayout, BLayout, CLayout>(kargs,
|
||||
invoke_batched_gemm<ALayout, BLayout, CLayout>(args,
|
||||
ck_tile::stream_config{nullptr, false});
|
||||
|
||||
std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K
|
||||
|
||||
@@ -31,22 +31,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
static constexpr auto PipelineType = std::tuple_element_t<8, Tuple>::value;
|
||||
// TODO: expose tile size through test t-param ?
|
||||
|
||||
struct gemm_args
|
||||
{
|
||||
const void* p_a;
|
||||
const void* p_b;
|
||||
void* p_c;
|
||||
ck_tile::index_t kbatch;
|
||||
ck_tile::index_t M;
|
||||
ck_tile::index_t N;
|
||||
ck_tile::index_t K;
|
||||
ck_tile::index_t stride_A;
|
||||
ck_tile::index_t stride_B;
|
||||
ck_tile::index_t stride_C;
|
||||
};
|
||||
|
||||
template <bool PadM, bool PadN, bool PadK>
|
||||
void invoke_gemm(const gemm_args& args, const ck_tile::stream_config& s)
|
||||
void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// TODO: This should be parameterized in tests
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
@@ -117,17 +103,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(args.p_a,
|
||||
args.p_b,
|
||||
args.p_c,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
args.stride_A,
|
||||
args.stride_B,
|
||||
args.stride_C);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.kbatch);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
@@ -319,11 +297,11 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
gemm_args args;
|
||||
args.p_a = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.p_b = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.p_c = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.kbatch = kbatch;
|
||||
ck_tile::GemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
|
||||
Reference in New Issue
Block a user