mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK-Tile] Enable vectorized reads on all layouts & improve perf. (#1835)
* Refactor universal gemm policy. * Adapt example to refactor changes. * Introduce static encoding pattern * Adding shuffled encoding patterns. * Fix err in reverse tuple. * Add transpose_tile2d * Small refactoring + doc * Enable reading on contiguous dimension in all layouts. * Transpose A/B register tile if needed for comp v3 pipeline. * Take contiguous dim size when calculating dram vector load size. * A/B smem pack size taken from WarpGemm attributes * Update B LDS layout and setup tile distribution pattern at class level. * Fix static assert. * Fix errors in examples. * Formatting & fix IsTranspose * Fix VectorSize & refactor. * Add error loging messages. * Fix VecLoadSize and TranspseC for mem pipeline. * Update unit-tests & disable mem pipeline. * Clang format * Update include/ck_tile/core/tensor/tile_window.hpp Co-authored-by: jakpiase <jakub.piasecki@amd.com> * Fix compilation and reviewers comments. * Refactor unit-test. Fallback to non-universal gemm. Need to use GemmPipelineAGmemBGmemCRegV1 for now, since GemmKernel is now supporting also non-K major vector reads. --------- Co-authored-by: jakpiase <jakub.piasecki@amd.com>
This commit is contained in:
@@ -16,6 +16,7 @@ enum struct GemmPipelineType
|
||||
Mem,
|
||||
Comp
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmPipeline : public ::testing::Test
|
||||
{
|
||||
@@ -51,6 +52,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
|
||||
// TODO: For now - but this should also be a test parameter
|
||||
constexpr bool TransposeC = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// ===============================================
|
||||
@@ -65,14 +69,16 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::
|
||||
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
|
||||
|
||||
using BaseGemmPipeline = std::conditional_t<
|
||||
PipelineType == GemmPipelineType::Mem,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<
|
||||
ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
std::conditional_t<PipelineType == GemmPipelineType::Mem,
|
||||
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
|
||||
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
|
||||
@@ -84,26 +90,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
using GemmPipeline =
|
||||
std::conditional_t<PipelineType == GemmPipelineType::Mem,
|
||||
ck_tile::GemmPipelineAgBgCrMem<
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
Traits,
|
||||
Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
Traits,
|
||||
Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
Scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
PipelineType == GemmPipelineType::Mem,
|
||||
ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem,
|
||||
ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
|
||||
ck_tile::UniversalGemmPipelineAgBgCrPolicy>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -129,70 +131,94 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
// Tail pipeline One to Seven
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
if constexpr(PipelineType == GemmPipelineType::Comp)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \""
|
||||
<< tail_num << "\" which is not supported! PrefetchStages: "
|
||||
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
|
||||
<< __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
|
||||
if constexpr(PipelineType == GemmPipelineType::Mem)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Two)
|
||||
// Tail pipeline One to Seven
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Two>{});
|
||||
ck_tile::TailNumber::One>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Three>{});
|
||||
ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Four)
|
||||
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Four>{});
|
||||
if(tail_num == ck_tile::TailNumber::Two)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Five)
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Five>{});
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Six)
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Six>{});
|
||||
if(tail_num == ck_tile::TailNumber::Four)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Four>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Seven)
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Seven>{});
|
||||
if(tail_num == ck_tile::TailNumber::Five)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Five>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Six)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Six>{});
|
||||
}
|
||||
}
|
||||
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Seven)
|
||||
{
|
||||
Run(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber,
|
||||
ck_tile::TailNumber::Seven>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user