mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
CK-Tile first draft of universal block gemm with interwave & intrawave scheduler (#1676)
* Block universal gemm. * Universal block gemm with interwave scheduler - draft. * Refactoring * Move a/b_warp_tiles into BlockGemmImpl * set BlockGemmImpl as a class member * Change tile size for more suitable to memory bound cases. * Introduce kKPerThread to WarpGemm * Add documentation comment. * Fix Interwave scheduler block gemm. * Add compute/memory friendly tile configuration. * Clean * New tile configurations in gemm mem example. * Add more static checks and fix loop order in block gemm. * Add more static checks and use warp gemm mfma dispatcher. * Add default scheduler block gemm. * Remove logging in example.
This commit is contained in:
@@ -247,8 +247,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
@@ -290,7 +290,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
{
|
||||
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
// block_gemm.LocalPrefetch();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
@@ -318,7 +318,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
|
||||
block_sync_lds();
|
||||
|
||||
// block_gemm.LocalPrefetch();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
@@ -331,14 +331,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
// block_gemm.LocalPrefetch();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
};
|
||||
|
||||
if constexpr(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
// block_gemm.LocalPrefetch();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Two)
|
||||
|
||||
@@ -11,6 +11,7 @@ namespace ck_tile {
|
||||
|
||||
enum struct GemmPipelineScheduler
|
||||
{
|
||||
Default,
|
||||
Intrawave,
|
||||
Interwave,
|
||||
};
|
||||
@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case ck_tile::GemmPipelineScheduler::Default: os << "Default"; break;
|
||||
case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break;
|
||||
case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break;
|
||||
default: os << "";
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -52,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// TODO: this 8 is AK1! should be a policy parameter!
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
@@ -264,6 +266,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -277,6 +282,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
constexpr index_t M0 = BlockSize / get_warp_size();
|
||||
constexpr index_t M1 = MPerBlock / (M2 * M0);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M1, M2 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
@@ -350,6 +358,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
static_assert(N0 * N1 * N2 == NPerBlock,
|
||||
"Incorrect N0, N1, N2 configuration! "
|
||||
"N0, N1, N2 must cover whole NPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
@@ -364,7 +375,9 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
constexpr index_t N0 = BlockSize / get_warp_size();
|
||||
constexpr index_t N1 = NPerBlock / (N2 * N0);
|
||||
|
||||
static_assert(N0 * N1 * N2 == NPerBlock,
|
||||
"Incorrect N0, N1, N2 configuration! "
|
||||
"N0, N1, N2 must cover whole NPerBlock!");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
@@ -475,9 +488,28 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1DefaultPolicy;
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
constexpr auto I2 = number<2>{};
|
||||
|
||||
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
TransposeC>;
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -33,6 +33,8 @@ struct GemmPipelineProblemBase
|
||||
static constexpr bool kPadN = GemmTraits::kPadN;
|
||||
static constexpr bool kPadK = GemmTraits::kPadK;
|
||||
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
|
||||
{
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
|
||||
Reference in New Issue
Block a user