mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK Tile] Improve the Layout, Padding, and Alignment features of CK Tile GEMM (#1651)
* Finished the feature * Modified the test file * Test case update * addresss comment * Addressed the review comment * Fixed the CI error
This commit is contained in:
@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_basic -j
|
||||
# The memory bound pipeline on the gemm calculation
|
||||
make tile_example_gemm_mem_pipeline -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_basic`
|
||||
|
||||
|
||||
@@ -17,10 +17,11 @@
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kPadC = true;
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr bool kTilePermute = false;
|
||||
// The rank and permutation will also be generate out by the CodeGen part.
|
||||
constexpr ck_tile::index_t kOutputRank = 2;
|
||||
@@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
CShuffleEpilogue,
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
kPadA,
|
||||
kPadB,
|
||||
kPadM,
|
||||
kPadN,
|
||||
kTilePermute,
|
||||
kOutputRank,
|
||||
1,
|
||||
@@ -65,13 +66,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
TilePartitioner::kM,
|
||||
TilePartitioner::kN>>,
|
||||
ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy<ALayout, BLayout, CLayout>;
|
||||
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
|
||||
using CodegenGemmPipeline =
|
||||
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
|
||||
@@ -31,9 +31,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kPadC = true;
|
||||
constexpr bool kPadM = true;
|
||||
constexpr bool kPadN = true;
|
||||
constexpr bool kPadK = true;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
@@ -46,9 +46,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
|
||||
|
||||
using GemmEpilogue = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
|
||||
|
||||
Reference in New Issue
Block a user