diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 81a9b08b70..089f968649 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -810,7 +810,6 @@ struct SelectedKernel {{ WarpTileN, // NPerXdl_ WarpTileK, // KPerXdl_ TransposeC, // isCTransposed_ - memory_operation, // MemoryOperation_ NumWaveGroups>; // kNumWaveGroups_ using GemmEpilogue = ck_tile::CShuffleEpilogue;""" @@ -827,15 +826,14 @@ struct SelectedKernel {{ DsLayout, CLayout, ElementWiseFn, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ + TileM, // kM_ + TileN, // kN_ WarpPerBlock_M, // MWave_ WarpPerBlock_N, // NWave_ WarpTileM, // MPerXdl_ WarpTileN, // NPerXdl_ WarpTileK, // KPerXdl_ - TransposeC, // isCTransposed_ - memory_operation>; // MemoryOperation_ + TransposeC>; // isCTransposed_ using GemmEpilogue = ck_tile::CShuffleEpilogue;""" return instance_code @@ -851,15 +849,14 @@ struct SelectedKernel {{ ck_tile::tuple<>, // DsLayout CLayout, ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ + TileM, // kM_ + TileN, // kN_ WarpPerBlock_M, // MWave_ WarpPerBlock_N, // NWave_ WarpTileM, // MPerXdl_ WarpTileN, // NPerXdl_ WarpTileK, // KPerXdl_ TransposeC, // isCTransposed_ - memory_operation, // MemoryOperation_ NumWaveGroups, // kNumWaveGroups_ false, // FixedVectorSize_ 1, // VectorSizeC_ @@ -879,8 +876,8 @@ struct SelectedKernel {{ ck_tile::tuple<>, // DsLayout CLayout, ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ + TileM, // kM_ + TileN, // kN_ kPadM, kPadN, WarpTileM, // kMPerXdl_ @@ -902,8 +899,8 @@ struct SelectedKernel {{ DsLayout, CLayout, ElementWiseFn, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ + TileM, // kM_ + TileN, // kN_ kPadM, kPadN, WarpTileM, // kMPerXdl_ @@ -925,8 +922,8 @@ struct SelectedKernel {{ ck_tile::tuple<>, // DsLayout CLayout, ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ + TileM, // kM_ + TileN, // kN_ kPadM, kPadN, WarpTileM, // kMPerXdl_