mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
Merge flatmm Operator with universal gemm (#2434)
* Initial commit * Adding new tile partitioner to flatmm * intermediate changes * debugging kernels * Updating flatmm example to universal gemm example * updated flatmm kernel to run via gemmKernel * update universal gemm to incorporate flatmm * debug * Fix flatmm call * Fixing other kernels and tests for API changes * clang formatted * fixing gemm tests * added test for flatmm and simplify kernel arguments * adding flatmm test * fix test for flatmm * simplify gemm kernel with flatmm * remove flatmm related files * addressing review comments and code clean up * resolving empty file * resolving empty file * clang formatted * addressing review comments * enable persistent kernel for flatmm * reverted the removed files for flatmm * reverted the removed files for flatmm * changed flatmm to weightPReshuffle; removed the _1 added in teh faltmm example * some more renames * clang formatted
This commit is contained in:
@@ -59,7 +59,8 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
@@ -71,7 +72,6 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
@@ -92,6 +92,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -101,7 +102,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
UniversalGemmProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
@@ -112,6 +113,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
@@ -135,7 +137,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
@@ -214,8 +216,21 @@ template <typename GemmConfig,
|
||||
typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
|
||||
}
|
||||
|
||||
if(preshuffle && a_layout != "R" && b_layout != "C")
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user