mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 06:44:36 +00:00
Fix the CK Tile related operators (#2356)
* fix the flatmm
* Fix the pipeline
* address the comment
[ROCm/composable_kernel commit: 3c4cdfac4f]
This commit is contained in:
@@ -69,9 +69,12 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
|
||||
@@ -166,7 +166,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
|
||||
@@ -49,9 +49,12 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
|
||||
Reference in New Issue
Block a user