mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Fix CI
This commit is contained in:
@@ -35,12 +35,12 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
using CodegenGemmTraits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<ALayout>,
|
||||
ck_tile::tuple<BLayout>,
|
||||
CLayout>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ck_tile::tuple<ADataType>,
|
||||
ck_tile::tuple<BDataType>,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits>;
|
||||
@@ -49,8 +49,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::CShuffleEpilogueProblem<ck_tile::tuple<ADataType>,
|
||||
ck_tile::tuple<BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
|
||||
@@ -18,7 +18,7 @@ struct BlockFlatmmASmemBSmemCRegV1
|
||||
using BlockPolicy = remove_cvref_t<BlockPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using EDataType = remove_cvref_t<typename Problem::EDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
@@ -61,7 +61,7 @@ struct BlockFlatmmASmemBSmemCRegV1
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
auto c_block_tensor = make_static_distributed_tensor<EDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
|
||||
@@ -64,13 +64,13 @@ struct FlatmmKernel
|
||||
using BlockGemmShape =
|
||||
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename FlatmmPipeline::AsLayout>;
|
||||
using BsLayout = remove_cvref_t<typename FlatmmPipeline::BsLayout>;
|
||||
using CLayout = remove_cvref_t<typename FlatmmPipeline::ELayout>;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
using AsDataType = remove_cvref_t<typename FlatmmPipeline::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename FlatmmPipeline::BsDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
@@ -81,6 +81,11 @@ struct FlatmmKernel
|
||||
static constexpr auto idxN = I1;
|
||||
static constexpr auto idxK = I2;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<I0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<I0, BsDataType>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<I0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<I0, BsLayout>>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
@@ -12,14 +12,21 @@ namespace ck_tile {
|
||||
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
|
||||
struct FlatmmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
|
||||
using EDataType = remove_cvref_t<typename Problem::EDataType>;
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayout>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayout>;
|
||||
using ELayout = remove_cvref_t<typename Problem::ELayout>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using BlockFlatmm =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
|
||||
|
||||
@@ -351,7 +351,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
using BlockFlatmmPolicy =
|
||||
BlockFlatmmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
typename Problem::EDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
|
||||
|
||||
@@ -144,19 +144,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
grid_size += grid_size_grp;
|
||||
|
||||
auto karg = GemmKernelArgs<>{
|
||||
{type_convert<const ADataType*>(gemm_descs[i].as_ptr[number<0>{}])},
|
||||
{type_convert<const BDataType*>(gemm_descs[i].bs_ptr[number<0>{}])},
|
||||
{},
|
||||
type_convert<EDataType*>(gemm_descs[i].e_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
{stride_a},
|
||||
{stride_b},
|
||||
{},
|
||||
stride_e,
|
||||
gemm_descs[i].k_batch};
|
||||
auto karg = GemmKernelArgs<>{{type_convert<const ADataType*>(gemm_descs[i].as_ptr[number<0>{}])},
|
||||
{type_convert<const BDataType*>(gemm_descs[i].bs_ptr[number<0>{}])},
|
||||
{},
|
||||
type_convert<EDataType*>(gemm_descs[i].e_ptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
{stride_a},
|
||||
{stride_b},
|
||||
{},
|
||||
stride_e,
|
||||
gemm_descs[i].k_batch};
|
||||
|
||||
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
|
||||
}
|
||||
@@ -199,10 +198,10 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
|
||||
|
||||
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
|
||||
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) +
|
||||
splitk_batch_offset.as_k_split_offset[0];
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) +
|
||||
splitk_batch_offset.bs_k_split_offset[0];
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[number<0>{}]) +
|
||||
splitk_batch_offset.as_k_split_offset[number<0>{}];
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[number<0>{}]) +
|
||||
splitk_batch_offset.bs_k_split_offset[number<0>{}];
|
||||
EDataType* c_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
|
||||
Reference in New Issue
Block a user