[CK-Tile] Universal gemm memory bound pipeline (#1558)

* CK-Tile GEMM with memory bound pipeline.

* Memory bound gemm pipeline.

* Fix not closed namespace.

* Block gemm mem pipeline draft.

* Do not use ck_tile:: within ck_tile namespace.

* Refactoring & Move Layout info to pipeline problem.

* Get hot loop and TailNum information before lunching kernel.

* Fixes in pipeline.

* Add comment to load_tile_raw and change variable naming style.

* Few small changes & formatting.

* Do not use macro.

* Add gtests.

* Use AccDataType for Output of MFMA instruction.

* Formatting.

* Refactor gemm examples.

* Switch over to current block gemm.

* Use currently available pipeline policy.

* Refactoring and review comment.s

* Fixes after merge.

* Add missing include.

* Add load tile overload which accepts output tensor as parameter.

* This give 8% perf boost at the cost of using more registers.

* Rename example.

* Small changes.

* Fix compilation err and lower K.

* Support different layouts for A/B

* Fix vector size for different layouts.

* Rename Alignment into VectorSize

* Unblock tests.
This commit is contained in:
Adam Osewski
2024-10-30 10:05:15 +01:00
committed by GitHub
parent 3d60953477
commit 24d996aae1
29 changed files with 1660 additions and 561 deletions

View File

@@ -3,11 +3,12 @@
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <iostream>
#include <string>
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
@@ -17,20 +18,19 @@ struct GemmKernel
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using LayoutA = remove_cvref_t<typename GemmPipeline::LayoutA>;
using LayoutB = remove_cvref_t<typename GemmPipeline::LayoutB>;
using LayoutC = remove_cvref_t<typename GemmPipeline::LayoutC>;
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
return TilePartitioner::GridSize(M_size, N_size, Batch_size);
return TilePartitioner::GridSize(M, N, KBatch);
}
__host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
@@ -40,34 +40,30 @@ struct GemmKernel
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
float epsilon;
ck_tile::index_t M;
ck_tile::index_t N;
ck_tile::index_t K;
ck_tile::index_t stride_A;
ck_tile::index_t stride_B;
ck_tile::index_t stride_C;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
CK_TILE_HOST static constexpr GemmCommonKargs MakeKargs(const void* a_ptr,
const void* b_ptr,
void* c_ptr,
float epsilon,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C)
index_t M,
index_t N,
index_t K,
index_t stride_A,
index_t stride_B,
index_t stride_C)
{
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, epsilon, M, N, K, stride_A, stride_B, stride_C};
return GemmCommonKargs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return ck_tile::max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
@@ -78,43 +74,43 @@ struct GemmKernel
const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// Convert pointers to tensor views
auto a_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutA, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<GemmPipeline::AlignmentA>{},
number<1>{});
}
else
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(kargs.stride_A, 1),
number<GemmPipeline::AlignmentA>{},
number<GemmPipeline::VectorSizeA>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_start,
make_tuple(kargs.M, kargs.K),
make_tuple(1, kargs.stride_A),
number<1>{},
number<1>{});
}
}();
auto b_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutB, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(1, kargs.stride_B),
number<GemmPipeline::AlignmentB>{},
number<1>{},
number<1>{});
}
else
{ // Default NK layout
{
return make_naive_tensor_view<address_space_enum::global>(
b_start,
make_tuple(kargs.N, kargs.K),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::AlignmentB>{},
number<GemmPipeline::VectorSizeB>{},
number<1>{});
}
}();
@@ -122,10 +118,12 @@ struct GemmKernel
auto a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadA ? 1 : 0 > {});
// somehow clang-format is splitting below line into multiple.
// clang-format off
sequence<false, GemmPipeline::kPadA>{});
// clang-format on
auto ABlockWindow = make_tile_window(
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0});
@@ -133,10 +131,11 @@ struct GemmKernel
auto b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadB ? 1 : 0 > {});
// clang-format off
sequence<false, GemmPipeline::kPadB>{});
// clang-format on
auto BBlockWindow = make_tile_window(
auto b_block_window = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0});
@@ -144,20 +143,21 @@ struct GemmKernel
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
const index_t num_loop = (kargs.K + TilePartitioner::kK - 1) / TilePartitioner::kK;
const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
auto acc = GemmPipeline{}(ABlockWindow, BBlockWindow, num_loop, smem_ptr);
CODataType* c_start = static_cast<CODataType*>(kargs.c_ptr);
// Run GEMM cooperatively by whole wokrgroup.
auto c_block_tile =
GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
auto c_tensor_view = [&]() {
if constexpr(std::is_same_v<LayoutC, tensor_layout::gemm::ColumnMajor>)
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(1, kargs.stride_C),
number<GemmPipeline::AlignmentC>{},
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::VectorSizeC>{},
number<1>{});
}
else
@@ -165,8 +165,8 @@ struct GemmKernel
return make_naive_tensor_view<address_space_enum::global>(
c_start,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
number<GemmPipeline::AlignmentC>{},
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
@@ -174,14 +174,15 @@ struct GemmKernel
auto c_pad_view = pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence < 0,
GemmPipeline::kPadC ? 1 : 0 > {});
auto CBlockWindow_pad = make_tile_window(
// clang-format off
sequence<false, GemmPipeline::kPadC>{});
// clang-format on
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n});
EpiloguePipeline{}(CBlockWindow_pad, acc);
EpiloguePipeline{}(c_block_window, c_block_tile);
}
};