mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK-Tile] Universal gemm memory bound pipeline (#1558)
* CK-Tile GEMM with memory bound pipeline. * Memory bound gemm pipeline. * Fix not closed namespace. * Block gemm mem pipeline draft. * Do not use ck_tile:: within ck_tile namespace. * Refactoring & Move Layout info to pipeline problem. * Get hot loop and TailNum information before launching the kernel. * Fixes in pipeline. * Add comment to load_tile_raw and change variable naming style. * Few small changes & formatting. * Do not use macro. * Add gtests. * Use AccDataType for Output of MFMA instruction. * Formatting. * Refactor gemm examples. * Switch over to current block gemm. * Use currently available pipeline policy. * Refactoring and review comments. * Fixes after merge. * Add missing include. * Add load tile overload which accepts output tensor as parameter. * This gives an 8% perf boost at the cost of using more registers. * Rename example. * Small changes. * Fix compilation error and lower K. * Support different layouts for A/B. * Fix vector size for different layouts. * Rename Alignment into VectorSize. * Unblock tests.
This commit is contained in:
@@ -3,11 +3,12 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include <iostream>
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -17,20 +18,19 @@ struct GemmKernel
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
|
||||
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using LayoutA = remove_cvref_t<typename GemmPipeline::LayoutA>;
|
||||
using LayoutB = remove_cvref_t<typename GemmPipeline::LayoutB>;
|
||||
using LayoutC = remove_cvref_t<typename GemmPipeline::LayoutC>;
|
||||
|
||||
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
|
||||
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return TilePartitioner::GridSize(M_size, N_size, Batch_size);
|
||||
return TilePartitioner::GridSize(M, N, KBatch);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
|
||||
@@ -40,34 +40,30 @@ struct GemmKernel
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
void* c_ptr;
|
||||
|
||||
float epsilon;
|
||||
|
||||
ck_tile::index_t M;
|
||||
ck_tile::index_t N;
|
||||
ck_tile::index_t K;
|
||||
ck_tile::index_t stride_A;
|
||||
ck_tile::index_t stride_B;
|
||||
ck_tile::index_t stride_C;
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
|
||||
const void* b_ptr,
|
||||
void* c_ptr,
|
||||
float epsilon,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C)
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_A,
|
||||
index_t stride_B,
|
||||
index_t stride_C)
|
||||
{
|
||||
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, M, N, K, stride_A, stride_B, stride_C};
|
||||
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
|
||||
@@ -78,43 +74,43 @@ struct GemmKernel
|
||||
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
// Convert pointers to tensor views
|
||||
auto a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_start,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(1, kargs.stride_A),
|
||||
number<GemmPipeline::AlignmentA>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_start,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<GemmPipeline::AlignmentA>{},
|
||||
number<GemmPipeline::VectorSizeA>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_start,
|
||||
make_tuple(kargs.M, kargs.K),
|
||||
make_tuple(1, kargs.stride_A),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto b_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::RowMajor>)
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_start,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(1, kargs.stride_B),
|
||||
number<GemmPipeline::AlignmentB>{},
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{ // Default NK layout
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_start,
|
||||
make_tuple(kargs.N, kargs.K),
|
||||
make_tuple(kargs.stride_B, 1),
|
||||
number<GemmPipeline::AlignmentB>{},
|
||||
number<GemmPipeline::VectorSizeB>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
@@ -122,10 +118,12 @@ struct GemmKernel
|
||||
auto a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadA ? 1 : 0 > {});
|
||||
// clang-format would otherwise split the line below across multiple lines, so it is disabled here.
|
||||
// clang-format off
|
||||
sequence<false, GemmPipeline::kPadA>{});
|
||||
// clang-format on
|
||||
|
||||
auto ABlockWindow = make_tile_window(
|
||||
auto a_block_window = make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
|
||||
{i_m, 0});
|
||||
@@ -133,10 +131,11 @@ struct GemmKernel
|
||||
auto b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadB ? 1 : 0 > {});
|
||||
// clang-format off
|
||||
sequence<false, GemmPipeline::kPadB>{});
|
||||
// clang-format on
|
||||
|
||||
auto BBlockWindow = make_tile_window(
|
||||
auto b_block_window = make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
|
||||
{i_n, 0});
|
||||
@@ -144,20 +143,21 @@ struct GemmKernel
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
|
||||
|
||||
auto acc = GemmPipeline{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr);
|
||||
|
||||
CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
|
||||
// Run GEMM cooperatively by the whole workgroup.
|
||||
auto c_block_tile =
|
||||
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
|
||||
|
||||
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
|
||||
auto c_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
c_start,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<GemmPipeline::AlignmentC>{},
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<GemmPipeline::VectorSizeC>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
@@ -165,8 +165,8 @@ struct GemmKernel
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
c_start,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<GemmPipeline::AlignmentC>{},
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
@@ -174,14 +174,15 @@ struct GemmKernel
|
||||
auto c_pad_view = pad_tensor_view(
|
||||
c_tensor_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
sequence < 0,
|
||||
GemmPipeline::kPadC ? 1 : 0 > {});
|
||||
auto CBlockWindow_pad = make_tile_window(
|
||||
// clang-format off
|
||||
sequence<false, GemmPipeline::kPadC>{});
|
||||
// clang-format on
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
EpiloguePipeline{}(CBlockWindow_pad, acc);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user