mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '5f1ad09b610cb0e083f63988479ab022bda70588' into develop
This commit is contained in:
@@ -18,6 +18,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added support for Split K for grouped convolution backward data.
|
||||
* Added logit soft-capping support for fMHA forward kernels.
|
||||
* Added benchmarking support for tile engine GEMM.
|
||||
* Added Ping-pong scheduler support for GEMM operation along the K dimension.
|
||||
* Added rotating buffer feature for CK_Tile GEMM.
|
||||
|
||||
### Optimized
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
|
||||
@@ -31,6 +32,10 @@
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV5
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV5
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#else
|
||||
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
|
||||
#endif
|
||||
@@ -51,7 +56,8 @@ struct GemmConfig
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
#endif
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
@@ -67,7 +73,8 @@ struct GemmConfig
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
@@ -83,7 +90,29 @@ struct GemmConfig
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V5)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 2;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
// Available wavegroups will be split into `NumWaveGroups` and each of these wavegroups
|
||||
// will be responsible for specific jobs. For instance, perform Global Memory read operations,
|
||||
// perform block-gemm operation etc...
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 2;
|
||||
#endif
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
|
||||
@@ -50,7 +50,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent>;
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
@@ -96,7 +97,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -190,7 +193,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
|
||||
@@ -56,19 +56,24 @@ template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
tile_distribution_pattern DistributionPattern>
|
||||
tile_distribution_pattern DistributionPattern,
|
||||
index_t NumWaveGroups = 1>
|
||||
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
|
||||
{
|
||||
};
|
||||
|
||||
// Thread raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::thread_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
tile_distribution_pattern::thread_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
@@ -83,45 +88,76 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
static constexpr index_t Y1 = warp_size / X0;
|
||||
static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
|
||||
|
||||
static constexpr index_t Y0 = num_warps;
|
||||
static constexpr index_t Y0 = num_warps / NumWaveGroups;
|
||||
// YPerWarp = YPerTile / Y0;
|
||||
// Y2 = YPerWarp / Y1;
|
||||
static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
|
||||
|
||||
static_assert(X0 * Y1 * Y0 == BlockSize, "X0 * warp_ys * Y0 must cover whole workgroup!");
|
||||
static_assert(X0 * Y1 * Y0 * NumWaveGroups == BlockSize,
|
||||
"X0 * warp_ys * Y0 must cover whole workgroup!");
|
||||
static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<0>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
if constexpr(NumWaveGroups != 1)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<X0, X1>, sequence<Y1, Y2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Warp raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::warp_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
tile_distribution_pattern::warp_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
@@ -164,13 +200,17 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
};
|
||||
|
||||
// Block raked
|
||||
template <index_t BlockSize, index_t YPerTile, index_t XPerTile, index_t VecSize>
|
||||
template <index_t BlockSize,
|
||||
index_t YPerTile,
|
||||
index_t XPerTile,
|
||||
index_t VecSize,
|
||||
index_t NumWaveGroups>
|
||||
struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
YPerTile,
|
||||
XPerTile,
|
||||
VecSize,
|
||||
tile_distribution_pattern::block_raked>
|
||||
: public TileDistributionEncodingPattern
|
||||
tile_distribution_pattern::block_raked,
|
||||
NumWaveGroups> : public TileDistributionEncodingPattern
|
||||
{
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
|
||||
@@ -23,7 +23,8 @@ template <typename ADataType_,
|
||||
index_t NPerXdl_,
|
||||
index_t KPerXdl_,
|
||||
bool isCTransposed_,
|
||||
memory_operation_enum MemoryOperation_>
|
||||
memory_operation_enum MemoryOperation_,
|
||||
index_t kNumWaveGroups_ = 1>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
@@ -41,6 +42,7 @@ struct CShuffleEpilogueProblem
|
||||
static constexpr index_t KPerXdl = KPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -236,7 +238,8 @@ struct CShuffleEpilogue
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
|
||||
@@ -31,6 +31,8 @@
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
|
||||
@@ -60,52 +60,105 @@ struct BlockGemmARegBRegCRegV1
|
||||
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
|
||||
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
|
||||
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
return a_block_dstr_encode;
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
@@ -201,19 +254,38 @@ struct BlockGemmARegBRegCRegV1
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
|
||||
@@ -644,6 +644,7 @@ struct GemmKernel
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
@@ -671,11 +672,15 @@ struct GemmKernel
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, smem_ptr_0);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, smem_ptr_0);
|
||||
EpiloguePipeline{}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -772,7 +777,9 @@ struct GemmKernel
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<CDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
|
||||
RunGemm<scheduler_type>(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,7 +71,8 @@ struct GemmPipelineAgBgCrImplBase
|
||||
template <typename ADramBlockWindowTmp, typename ALdsTensorView, typename ALdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr&) const
|
||||
const ALdsLoadTileDistr&,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
@@ -82,7 +83,7 @@ struct GemmPipelineAgBgCrImplBase
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
a_dram_block_window_tmp.get_window_origin(),
|
||||
a_dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
@@ -103,7 +104,8 @@ struct GemmPipelineAgBgCrImplBase
|
||||
template <typename BDramBlockWindowTmp, typename BLdsTensorView, typename BLdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BLdsTensorView& b_lds_block_view,
|
||||
const BLdsLoadTileDistr&) const
|
||||
const BLdsLoadTileDistr&,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
@@ -113,7 +115,7 @@ struct GemmPipelineAgBgCrImplBase
|
||||
auto b_copy_dram_window =
|
||||
make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
b_dram_block_window_tmp.get_window_origin(),
|
||||
b_dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// TODO: Do we really need those two tile windows???
|
||||
|
||||
@@ -143,6 +143,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
@@ -134,6 +134,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
|
||||
@@ -0,0 +1,379 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed Tensor: register
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAgBgCrCompV5
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV5DefaultPolicy>
|
||||
struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV5<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
static constexpr auto TailNum = Problem::TailNum;
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr index_t NumWarps = BlockGemmShape::NumWarps;
|
||||
static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{});
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AgBgCrCompV5", BlockSize,
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
{
|
||||
return Policy::template IsTransposeC<Problem>();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"Data Type conflict on A and B matrix input data type.");
|
||||
|
||||
static_assert(
|
||||
KPerBlock % ((NumWarps / 2) * KTileSize) == 0,
|
||||
"Ping Pong Warps, TileSize and Block Size for K dimensions does not match.");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
index_t warp_id = get_warp_id();
|
||||
index_t operation_id =
|
||||
__builtin_amdgcn_readfirstlane(get_warp_id()); // 0 - Memory read, 1 - block-gemm
|
||||
|
||||
auto a_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock);
|
||||
auto b_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock);
|
||||
|
||||
auto tensor_views =
|
||||
Base::GetABLdsTensorViews(static_cast<void*>(static_cast<char*>(p_smem_0)));
|
||||
auto& a_lds_block = tensor_views.get(number<0>{});
|
||||
auto& b_lds_block = tensor_views.get(number<1>{});
|
||||
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
auto a_windows = Base::GetAWindows(
|
||||
a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr, a_offset);
|
||||
auto& a_copy_dram_window = a_windows.get(number<0>{});
|
||||
auto& a_copy_lds_window = a_windows.get(number<1>{});
|
||||
auto& a_lds_window = a_windows.get(number<2>{});
|
||||
|
||||
auto b_windows = Base::GetBWindows(
|
||||
b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr, b_offset);
|
||||
auto& b_copy_dram_window = b_windows.get(number<0>{});
|
||||
auto& b_copy_lds_window = b_windows.get(number<1>{});
|
||||
auto& b_lds_window = b_windows.get(number<2>{});
|
||||
|
||||
// DRAM window steps.
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock * NumWarps, 0)
|
||||
: make_array(0, KPerBlock * NumWarps);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock * NumWarps, 0)
|
||||
: make_array(0, KPerBlock * NumWarps);
|
||||
|
||||
constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
|
||||
using AGemmTile = decltype(make_static_distributed_tensor<ADataType>(AGemmTileDistr));
|
||||
using BGemmTile = decltype(make_static_distributed_tensor<BDataType>(BGemmTileDistr));
|
||||
AGemmTile a_tile_0, a_tile_1;
|
||||
BGemmTile b_tile_0, b_tile_1;
|
||||
|
||||
// Register tile for A and B.
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
ABlockTile a_global_load_tile;
|
||||
BBlockTile b_global_load_tile;
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile_0 = block_gemm.MakeCBlockTile();
|
||||
auto c_block_tile_1 = block_gemm.MakeCBlockTile();
|
||||
|
||||
CDataType* __restrict__ p_c_lds = static_cast<CDataType*>(p_smem_0);
|
||||
auto c_lds_block_0 =
|
||||
make_naive_tensor_view<address_space_enum::lds>(p_c_lds,
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
make_tuple(NPerBlock, 1),
|
||||
number<BlockGemm::Traits::KPack>{},
|
||||
number<1>{});
|
||||
auto c_window_0 = make_tile_window(c_lds_block_0,
|
||||
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
|
||||
{0, 0},
|
||||
c_block_tile_1.get_tile_distribution());
|
||||
|
||||
// initialize C
|
||||
if(warp_id == 0)
|
||||
{
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_1);
|
||||
}
|
||||
|
||||
// define ping, pong steps here as lambda functions.
|
||||
auto MemoryOpsStep = [&](auto idx) {
|
||||
// Memory read half here.
|
||||
Base::GlobalPrefetch(
|
||||
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetch(
|
||||
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func);
|
||||
}
|
||||
|
||||
if(idx == 0)
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_0, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_0, b_lds_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_1, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_1, b_lds_window);
|
||||
}
|
||||
};
|
||||
|
||||
auto ComputeStep = [&](auto idx) {
|
||||
if(idx == 0)
|
||||
{
|
||||
block_gemm(c_block_tile_0, a_tile_0, b_tile_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm(c_block_tile_1, a_tile_1, b_tile_1);
|
||||
}
|
||||
};
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
MemoryOpsStep(warp_id);
|
||||
}
|
||||
|
||||
index_t num_compute_steps = __builtin_amdgcn_readfirstlane(num_loop);
|
||||
while(num_compute_steps > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
operation_id = (operation_id + 1) % NumWaveGroups;
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
MemoryOpsStep(warp_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
num_compute_steps -= 1;
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
if(warp_id == 1)
|
||||
{
|
||||
store_tile(c_window_0, c_block_tile_1);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
if(warp_id == 0)
|
||||
{
|
||||
load_tile(c_block_tile_1, c_window_0);
|
||||
|
||||
constexpr auto s_spans = decltype(c_block_tile_0)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
auto idx2 = make_tuple(idx0, idx1);
|
||||
c_block_tile_0(idx2) += c_block_tile_1(idx2);
|
||||
});
|
||||
});
|
||||
}
|
||||
return c_block_tile_0;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Default policy for GemmPipelineAGmemBGmemCregComputeV5, except the block gemm method, it shares
|
||||
// the same vector size implementation, SmemSize, Global memory tile distiribution as the
|
||||
// UniversalGemm Pipeline Policy.
|
||||
// Default policy class should not be templated, put template on
|
||||
// member functions instead.
|
||||
struct GemmPipelineAgBgCrCompV5DefaultPolicy
|
||||
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV5DefaultPolicy>
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeC()
|
||||
{
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
|
||||
return integer_least_multiple(sizeof(typename Problem::CDataType) * MPerBlock * NPerBlock,
|
||||
16);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
constexpr index_t smem_size_c = GetSmemSizeC<Problem>();
|
||||
|
||||
return smem_size_a + smem_size_b >= smem_size_c ? (smem_size_a + smem_size_b)
|
||||
: (smem_size_c);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -188,6 +188,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// Where is the right place for HasHotLoop and TailNum ???
|
||||
static constexpr bool HasHotLoop = Problem::HasHotLoop;
|
||||
|
||||
@@ -198,6 +198,8 @@ struct UniversalGemmPipelineProblem
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -426,10 +426,11 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// Tile: MPerBlock X KPerBlock
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -438,7 +439,8 @@ struct UniversalGemmBasePolicy
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern>;
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
// Tile: KPerBlock X MPerBlock
|
||||
@@ -448,7 +450,8 @@ struct UniversalGemmBasePolicy
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern>;
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
}
|
||||
@@ -458,10 +461,11 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -470,7 +474,8 @@ struct UniversalGemmBasePolicy
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern>;
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
// Tile: NPerBlock X KPerBlock
|
||||
@@ -480,7 +485,8 @@ struct UniversalGemmBasePolicy
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern>;
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
}
|
||||
}
|
||||
@@ -490,16 +496,18 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
ATileAccessPattern>;
|
||||
ATileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
@@ -508,16 +516,18 @@ struct UniversalGemmBasePolicy
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern>;
|
||||
BTileAccessPattern,
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +39,8 @@ template <bool kPadM_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false,
|
||||
bool UsePersistentKernel_ = false>
|
||||
bool UsePersistentKernel_ = false,
|
||||
index_t NumWaveGroups_ = 1>
|
||||
struct TileGemmUniversalTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
@@ -55,6 +56,7 @@ struct TileGemmUniversalTraits
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
|
||||
Reference in New Issue
Block a user