mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)
* initial poc * factor out common parts in operator() * cv4 * rest of the universal gemm pipelines * fix test * remove boilerplate from tile engine * fix example * fix example * format * fix tests build for gemm * remove base pipeline codegen from gemm instance builder * unify v3 logic with the rest of universal gemm pipelines * fix build for multi abd test * fix test gemm multi d * fix build for weight preshuffle * fix grouped gemm test * fix grouped gemm multi d test * fix grouped gemm preshuffle * fix grouped gemm example except for quant * fix gemm preshuffle * fix splitk 2 stage example * fix batched gemm example * fix multid example * fix multiabd example * fix batched gemm test * fixup * fix examples build * fix grouped gemm test build * fix smoke builder * hacky poc * fix tile engine * kill the lambda * maybe fix test build * more fixes * clang-format * save temp * clang-format * mostly fix examples * clang-format * remove dead code * more cleanup * fix fmha bwd build (default epilogue set/add appears to be broken) * fix default epilogue tests but not correctness * clang-format * fix bquant * clang-format * cleanup dead code * rearrange make windows for readability * restore changes to IsSupportedArgument * fix smoke-builder * clang-format * fixup rename class * build fixes * clang-format * fix builder * fixup * remove set from builder tests * fix test * clang-format * re-refactor the kernels * clang-format * fix header license * remove memory operation from conv bwd test * clang-format * clang-format example,include * clang-format test * build fixes * clang-format * solve compilation error * fix the CI * solve compilation error * clang format * solve merge conflict * solve merge conflict * solve the gfx11 error * solve test error * moar build fixes * remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -222,19 +222,13 @@ struct StreamKKernel
|
||||
const index_t block_idx_n,
|
||||
const index_t k_size)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
UniversalGemmKernel::MakeABlockWindows(as_ptr, kargs, k_size, block_idx_m);
|
||||
const auto& bs_block_window =
|
||||
UniversalGemmKernel::MakeBBlockWindows(bs_ptr, kargs, k_size, block_idx_n);
|
||||
const auto& ds_block_window =
|
||||
UniversalGemmKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
|
||||
// has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
|
||||
@@ -243,6 +237,7 @@ struct StreamKKernel
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop,
|
||||
@@ -253,7 +248,9 @@ struct StreamKKernel
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
auto c_block_window =
|
||||
UniversalGemmKernel::template MakeCBlockWindows<TilePartitioner::MemoryOperation>(
|
||||
c_ptr, kargs, block_idx_m, block_idx_n);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
@@ -525,21 +522,13 @@ struct StreamKKernel
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<
|
||||
EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views =
|
||||
UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
// Create block windows using specialized methods
|
||||
const auto& as_block_window =
|
||||
UniversalGemmKernel::MakeABlockWindows({a_ptr}, kargs, k_size, i_m);
|
||||
const auto& bs_block_window =
|
||||
UniversalGemmKernel::MakeBBlockWindows({b_ptr}, kargs, k_size, i_n);
|
||||
const auto& ds_block_window =
|
||||
UniversalGemmKernel::MakeDBlockWindows({/*ds_ptr*/}, kargs, i_m, i_n);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop,
|
||||
// we compute has_hot_loop and tail_num here. This is a similar pattern used by
|
||||
@@ -548,6 +537,7 @@ struct StreamKKernel
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop_sk,
|
||||
@@ -594,7 +584,8 @@ struct StreamKKernel
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
|
||||
TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
@@ -617,7 +608,8 @@ struct StreamKKernel
|
||||
// tensor.
|
||||
if(tile_started && !partner_in_tile)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows<
|
||||
TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
break;
|
||||
|
||||
@@ -27,6 +27,9 @@ struct StreamKTilePartitionerBase
|
||||
static constexpr index_t NPerBlock = BlockGemmShapeType::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShapeType::kK;
|
||||
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
|
||||
static constexpr auto MemoryOperation = (ReductionStrategy == StreamKReductionStrategy::Atomic)
|
||||
? memory_operation_enum::atomic_add
|
||||
: memory_operation_enum::set;
|
||||
|
||||
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user