diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index f96f558101..6eb899ba8c 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -35,12 +35,12 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con using CodegenGemmTraits = ck_tile::TileGemmTraits, + ck_tile::tuple, CLayout>; - using CodegenPipelineProblem = ck_tile::GemmPipelineProblem, + ck_tile::tuple, AccDataType, CodegenFlatmmShape, CodegenGemmTraits>; @@ -49,8 +49,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con constexpr auto memory_operation = memory_operation_.value; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, + ck_tile::tuple, ck_tile::tuple<>, AccDataType, CDataType, diff --git a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp index 18b2fe6483..549d59e1ef 100644 --- a/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp @@ -18,7 +18,7 @@ struct BlockFlatmmASmemBSmemCRegV1 using BlockPolicy = remove_cvref_t; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using EDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape static constexpr auto I0 = number<0>(); @@ -61,7 +61,7 @@ struct BlockFlatmmASmemBSmemCRegV1 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); return c_block_tensor; } diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index d2e1bde58f..cc4249c4d1 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -64,13 +64,13 @@ struct FlatmmKernel using BlockGemmShape = remove_cvref_t; // TileFlatmmShape using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; // Below type is actually accumulation data type - the output of block GEMM. using CDataType = remove_cvref_t; @@ -81,6 +81,11 @@ struct FlatmmKernel static constexpr auto idxN = I1; static constexpr auto idxK = I2; + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 648b2b85bd..ad7ed18329 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -12,14 +12,21 @@ namespace ck_tile { template struct FlatmmPipelineAGmemBGmemCRegV1 { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using EDataType = remove_cvref_t; + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; using BlockFlatmm = remove_cvref_t())>; diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 5c33666ec4..df1f0d7873 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -351,7 +351,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy; return BlockFlatmmASmemBSmemCRegV1{}; diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index fd56200db5..0e6c111815 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -144,19 +144,18 @@ struct GroupedGemmKernel : public GemmKernel{ - {type_convert(gemm_descs[i].as_ptr[number<0>{}])}, - {type_convert(gemm_descs[i].bs_ptr[number<0>{}])}, - {}, - type_convert(gemm_descs[i].e_ptr), - M, - N, - K, - {stride_a}, - {stride_b}, - {}, - stride_e, - gemm_descs[i].k_batch}; + auto karg = GemmKernelArgs<>{{type_convert(gemm_descs[i].as_ptr[number<0>{}])}, + {type_convert(gemm_descs[i].bs_ptr[number<0>{}])}, + {}, + type_convert(gemm_descs[i].e_ptr), + M, + N, + K, + {stride_a}, + {stride_b}, + {}, + stride_e, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -199,10 +198,10 @@ struct GroupedGemmKernel : public GemmKernel(kargs.as_ptr[0]) + - splitk_batch_offset.as_k_split_offset[0]; - const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + - splitk_batch_offset.bs_k_split_offset[0]; + const ADataType* a_ptr = static_cast(kargs.as_ptr[number<0>{}]) + + splitk_batch_offset.as_k_split_offset[number<0>{}]; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[number<0>{}]) + + splitk_batch_offset.bs_k_split_offset[number<0>{}]; EDataType* c_ptr = static_cast(kargs.e_ptr); // allocate LDS