This commit is contained in:
Mateusz Ozga
2025-07-01 21:48:43 +00:00
parent 1e64412d35
commit ce8eb916a2
6 changed files with 48 additions and 37 deletions

View File

@@ -35,12 +35,12 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
using CodegenGemmTraits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
ck_tile::tuple<ALayout>,
ck_tile::tuple<BLayout>,
CLayout>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ck_tile::tuple<ADataType>,
ck_tile::tuple<BDataType>,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits>;
@@ -49,8 +49,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
constexpr auto memory_operation = memory_operation_.value;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::CShuffleEpilogueProblem<ck_tile::tuple<ADataType>,
ck_tile::tuple<BDataType>,
ck_tile::tuple<>,
AccDataType,
CDataType,

View File

@@ -18,7 +18,7 @@ struct BlockFlatmmASmemBSmemCRegV1
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using EDataType = remove_cvref_t<typename Problem::EDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
static constexpr auto I0 = number<0>();
@@ -61,7 +61,7 @@ struct BlockFlatmmASmemBSmemCRegV1
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
auto c_block_tensor = make_static_distributed_tensor<EDataType>(c_block_dstr);
return c_block_tensor;
}

View File

@@ -64,13 +64,13 @@ struct FlatmmKernel
using BlockGemmShape =
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using CLayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using AsLayout = remove_cvref_t<typename FlatmmPipeline::AsLayout>;
using BsLayout = remove_cvref_t<typename FlatmmPipeline::BsLayout>;
using CLayout = remove_cvref_t<typename FlatmmPipeline::ELayout>;
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
using AsDataType = remove_cvref_t<typename FlatmmPipeline::AsDataType>;
using BsDataType = remove_cvref_t<typename FlatmmPipeline::BsDataType>;
// The type below is actually the accumulation data type - the output of the block GEMM.
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
@@ -81,6 +81,11 @@ struct FlatmmKernel
static constexpr auto idxN = I1;
static constexpr auto idxK = I2;
using ADataType = remove_cvref_t<std::tuple_element_t<I0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<I0, BsDataType>>;
using ALayout = remove_cvref_t<std::tuple_element_t<I0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<I0, BsLayout>>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off

View File

@@ -12,14 +12,21 @@ namespace ck_tile {
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
struct FlatmmPipelineAGmemBGmemCRegV1
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
using EDataType = remove_cvref_t<typename Problem::EDataType>;
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using AsLayout = remove_cvref_t<typename Problem::AsLayout>;
using BsLayout = remove_cvref_t<typename Problem::BsLayout>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
using BlockFlatmm =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;

View File

@@ -351,7 +351,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
using BlockFlatmmPolicy =
BlockFlatmmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
typename Problem::EDataType,
BlockWarps,
WarpGemm>;
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};

View File

@@ -144,19 +144,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
grid_size += grid_size_grp;
auto karg = GemmKernelArgs<>{
{type_convert<const ADataType*>(gemm_descs[i].as_ptr[number<0>{}])},
{type_convert<const BDataType*>(gemm_descs[i].bs_ptr[number<0>{}])},
{},
type_convert<EDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
{stride_a},
{stride_b},
{},
stride_e,
gemm_descs[i].k_batch};
auto karg = GemmKernelArgs<>{{type_convert<const ADataType*>(gemm_descs[i].as_ptr[number<0>{}])},
{type_convert<const BDataType*>(gemm_descs[i].bs_ptr[number<0>{}])},
{},
type_convert<EDataType*>(gemm_descs[i].e_ptr),
M,
N,
K,
{stride_a},
{stride_b},
{},
stride_e,
gemm_descs[i].k_batch};
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
}
@@ -199,10 +198,10 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) +
splitk_batch_offset.as_k_split_offset[0];
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) +
splitk_batch_offset.bs_k_split_offset[0];
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[number<0>{}]) +
splitk_batch_offset.as_k_split_offset[number<0>{}];
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[number<0>{}]) +
splitk_batch_offset.bs_k_split_offset[number<0>{}];
EDataType* c_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS