mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)
* initial poc * factor out common parts in operator() * cv4 * rest of the universal gemm pipelines * fix test * remove boilerplate from tile engine * fix example * fix example * format * fix tests build for gemm * remove base pipeline codegen from gemm instance builder * unify v3 logic with the rest of universal gemm pipelines * fix build for multi abd test * fix test gemm multi d * fix build for weight preshuffle * fix grouped gemm test * fix grouped gemm multi d test * fix grouped gemm preshuffle * fix grouped gemm example except for quant * fix gemm preshuffle * fix splitk 2 stage example * fix batched gemm example * fix multid example * fix multiabd example * fix batched gemm test * fixup * fix examples build * fix grouped gemm test build * fix smoke builder * hacky poc * fix tile engine * kill the lambda * maybe fix test build * more fixes * clang-format * save temp * clang-format * mostly fix examples * clang-format * remove dead code * more cleanup * fix fmha bwd build (default epilogue set/add appears to be broken) * fix default epilogue tests but not correctness * clang-format * fix bquant * clang-format * cleanup dead code * rearrange make windows for readability * restore changes to IsSupportedArgument * fix smoke-builder * clang-format * fixup rename class * build fixes * clang-format * fix builder * fixup * remove set from builder tests * fix test * clang-format * re-refactor the kernels * clang-format * fix header license * remove memory operation from conv bwd test * clang-format * clang-format example,include * clang-format test * build fixes * clang-format * solve compilation error * fix the CI * solve compilation error * clang format * solve merge conflict * solve merge conflict * solve the gfx11 error * solve test error * moar build fixes * remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -719,8 +719,8 @@ struct SelectedKernel {{
|
||||
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
|
||||
instance_code += f"""
|
||||
|
||||
// Kernel type
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
// Kernel type
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
// Kernel arguments
|
||||
auto kargs = GemmKernel::MakeKernelArgs(args);
|
||||
@@ -802,8 +802,8 @@ struct SelectedKernel {{
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock, // kM_
|
||||
TilePartitioner::NPerBlock, // kN_
|
||||
TileM, // kM_
|
||||
TileN, // kN_
|
||||
WarpPerBlock_M, // MWave_
|
||||
WarpPerBlock_N, // NWave_
|
||||
WarpTileM, // MPerXdl_
|
||||
|
||||
@@ -481,8 +481,6 @@ struct SelectedKernel {{
|
||||
GemmUniversalTraits>;
|
||||
|
||||
static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{
|
||||
const auto Run = [&](const auto memory_operation_) {{
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
@@ -512,7 +510,6 @@ struct SelectedKernel {{
|
||||
WarpTileN, // NPerXdl_
|
||||
WarpTileK, // KPerXdl_
|
||||
TransposeC, // isCTransposed_
|
||||
memory_operation, // MemoryOperation_
|
||||
NumWaveGroups>; // kNumWaveGroups_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
|
||||
@@ -558,30 +555,12 @@ struct SelectedKernel {{
|
||||
workspace_data.SetZero();
|
||||
}}
|
||||
}};
|
||||
|
||||
|
||||
|
||||
// Launch kernel
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
stream,
|
||||
reset_data_buffers,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
|
||||
// ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
// return std::make_tuple(ave_time, num_wgs_per_tile);
|
||||
}};
|
||||
|
||||
|
||||
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == reduction_strategy)
|
||||
{{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{{}});
|
||||
}}
|
||||
else // We are using ck_tile::StreamKReductionStrategy::Reduction
|
||||
{{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{{}});
|
||||
}}
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user