This commit is contained in:
Mateusz Ozga
2025-07-01 21:48:43 +00:00
parent 1e64412d35
commit ce8eb916a2
6 changed files with 48 additions and 37 deletions

View File

@@ -35,12 +35,12 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
using CodegenGemmTraits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
ck_tile::tuple<ALayout>,
ck_tile::tuple<BLayout>,
CLayout>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ck_tile::tuple<ADataType>,
ck_tile::tuple<BDataType>,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits>;
@@ -49,8 +49,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::CShuffleEpilogueProblem<ck_tile::tuple<ADataType>,
ck_tile::tuple<BDataType>,
ck_tile::tuple<>,
AccDataType,
CDataType,