diff --git a/CHANGELOG.md b/CHANGELOG.md index d9fad8c6d6..6229e0fd6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -88,6 +88,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added Ping-pong scheduler support for GEMM operation along the K dimension. * Added rotating buffer feature for CK_Tile GEMM. * Added int8 support for CK_TILE GEMM. +* Added CK Tile Epilogue Chainer framework for composable epilogue sequences in GEMM operations ### Optimized diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp index cf29ee706e..9e2bc3e3fb 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -84,24 +84,51 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& const auto Run = [&](const auto memory_operation_) { constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + // Epilogue selection: set to true for chainer-based, false for standard + // CShuffleEpilogue + constexpr bool UseChainerEpilogue = true; + + using GemmEpilogue = std::conditional_t< + UseChainerEpilogue, + // Chainer-based epilogue + ck_tile::EpilogueChainer, + ck_tile::DefaultScheduleTag>>, + // Standard CShuffleEpilogue + ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>>; using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index b76528cbaa..2ddb96f620 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -16,6 +16,7 @@ #include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/tensor_shuffle_utils.hpp" #include 
"ck_tile/ops/gemm_quant.hpp" +#include "ck_tile/ops/epilogue.hpp" #include "gemm_utils.hpp" template , + + // Epilogue selection: use chainer for RowCol/Tensor quant, standard for others + // Toggle to switch between chainer-based and standard CShuffleEpilogue + constexpr bool UseChainerEpilogue = true; + + // Define the schedule tag based on quant mode + using ScheduleTag = + std::conditional_t>; + + using GemmEpilogue = std::conditional_t< + UseChainerEpilogue && (QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant), + // Chainer-based epilogue for RowCol/Tensor quant modes + ck_tile::EpilogueChainer, + typename TypeConfig::ADataType, + typename TypeConfig::BDataType>, + ck_tile::tuple<>, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + ck_tile::memory_operation_enum::set, + 1, + false, + 1, + TiledPermuteN>, + ScheduleTag>>, + // Standard CShuffleEpilogue for other modes + ck_tile::CShuffleEpilogue, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>>; + std::conditional_t< + std::is_same_v, + typename TypeConfig::ADataType, + typename TypeConfig::BDataType>, + ck_tile::tuple<>, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, 
+ GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + ck_tile::memory_operation_enum::set, + 1, + false, + 1, + TiledPermuteN>>>; + using Kernel = ck_tile::QuantGemmKernel; diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 555402b53a..433462b22e 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -2,6 +2,10 @@ // SPDX-License-Identifier: MIT #pragma once +#include "ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp" +#include "ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp" +#include "ck_tile/ops/epilogue/chainer/cshuffle_epilogue_schedule.hpp" +#include "ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp" #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" diff --git a/include/ck_tile/ops/epilogue/chainer/README.md b/include/ck_tile/ops/epilogue/chainer/README.md new file mode 100644 index 0000000000..87581ebc9a --- /dev/null +++ b/include/ck_tile/ops/epilogue/chainer/README.md @@ -0,0 +1,61 @@ +# CK Tile Epilogue Chainer + +## Overview + +The Epilogue Chainer provides a modular epilogue processing framework through scheduler-defined operation graphs. + +## Architecture + +### Core Design Principle +The chainer follows a **Scheduler-Graph-Node** architecture with shared context: +- **Scheduler**: Defines operation graphs and creates a shared context +- **Graph**: Composes multiple operations into sequential processing units +- **Node**: Wraps individual epilogue operations with their arguments + +### EpilogueChainer +The `EpilogueChainer` struct serves as the modular epilogue processing facilitator. It delegates to schedulers for context creation and schedule generation, then processes the resulting operation graphs. 
+ +### EpilogueNode +Individual epilogue operations are wrapped in `EpilogueNode` structures that capture required arguments at construction time and automatically forward them during processing. Supports both parameterized and parameter-free operations. + +### EpilogueGraph +The `EpilogueGraph` composes multiple nodes into sequential processing units that iterate over multiple accesses if needed, running all operations in order for each iteration. + +## Files + +### Core Infrastructure +- `epilogue_chainer.hpp` - General chainer, node, and graph infrastructure +- `common_epilogue_ops.hpp` - Epilogue operations usable with any epilogue type + +### CShuffle Implementation +- `cshuffle_epilogue_chainer_ops.hpp` - CShuffle-specific problem, context, and slice operations +- `cshuffle_epilogue_schedule.hpp` - CShuffle scheduler with pre-built schedules + +## Usage + +### Common Operations (common_epilogue_ops.hpp) +These operations work with any context that provides the standardized interface: +- `ScaleScalarOp` - Scale working-tile by scalar values +- `CastAndStoreToLdsOp` - Cast working-tile and store to LDS +- `LoadFromLdsOp` - Load output tile from LDS with sync +- `ElementwiseOp` - Apply elementwise operation with auxiliary tensors +- `StoreOp` - Store output tile to global memory +- `MoveWindowsOp` - Advance windows to next position + +### CShuffle-Specific Operations (cshuffle_epilogue_chainer_ops.hpp) +These operations are specific to CShuffle epilogue: +- `CShuffleSliceOp` - Slice accumulator tile based on distribution +- `CShuffleScaleWindowOp` - Scale using tensor windows with shuffle distribution + +### Context Interface +Operations communicate through a shared context with standardized members: +- `working_tile`: Tile for intermediate computations +- `out_tile`: Output tile +- `aux_windows`: Tuple of auxiliary tensor windows +- `lds_write_window`: Window for writing to LDS +- `lds_read_window`: Window for reading from LDS + +### Schedule Tags +- 
`DefaultScheduleTag` - Standard: Slice → CastStore → Load → ApplyD → Store → Move +- `RowColQuantScheduleTag` - With window scaling +- `TensorQuantScheduleTag` - With scalar scaling \ No newline at end of file diff --git a/include/ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp b/include/ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp new file mode 100644 index 0000000000..a0b2e7845c --- /dev/null +++ b/include/ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp @@ -0,0 +1,208 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" + +/// @file common_epilogue_ops.hpp +/// @brief Reusable simple epilogue operations which might be used to compose more complex one. +/// +/// +/// @par Context Interface +/// Operations expect the context to provide: +/// - working_tile: Tile for intermediate computations +/// - out_tile: Output tile for final results +/// - aux_windows: Tuple of auxiliary tensor windows (e.g., D tensors) +/// - lds_write_window: Window for writing to LDS (if using LDS) +/// - lds_read_window: Window for reading from LDS (if using LDS) +namespace ck_tile { + +/// @brief Scale working tile by scalar values +/// +/// @par Context Requirements +/// working_tile: Tile to scale +struct ScaleScalarOp +{ + template + CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window, + [[maybe_unused]] const AccTile& acc_tile, + [[maybe_unused]] const AuxWindows& aux_windows, + [[maybe_unused]] void* p_smem, + [[maybe_unused]] IAccess iAccess, + Context& context, + const ScaleA& scale_a, + const ScaleB& scale_b) + { + tile_elementwise_inout([&](auto& elem) { elem = elem * scale_a * scale_b; }, + context.working_tile); + } +}; + +/// @brief Cast working tile and store to LDS +/// +/// @tparam DataType Target data type for casting +/// +/// @par Context Requirements +/// working_tile: Tile to cast +/// 
lds_write_window: Window for writing to LDS +template +struct CastAndStoreToLdsOp +{ + template + CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window, + [[maybe_unused]] const AccTile& acc_tile, + [[maybe_unused]] const AuxWindows& aux_windows, + [[maybe_unused]] void* p_smem, + [[maybe_unused]] IAccess iAccess, + Context& context) + { + const auto casted_tile = cast_tile(context.working_tile); + store_tile(context.lds_write_window, casted_tile); + } +}; + +/// @brief Load output tile from LDS with synchronization +/// +/// @tparam TileEncodingPattern Pattern for tile distribution +/// +/// @par Context Requirements +/// lds_read_window: Window for reading from LDS +/// out_tile: Destination for loaded tile +template +struct LoadFromLdsOp +{ + template + CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window, + [[maybe_unused]] const AccTile& acc_tile, + [[maybe_unused]] const AuxWindows& aux_windows, + [[maybe_unused]] void* p_smem, + [[maybe_unused]] IAccess iAccess, + Context& context) + { + constexpr auto tile_distribution = TileEncodingPattern::make_2d_static_tile_distribution(); + block_sync_lds(); + context.out_tile = load_tile(make_tile_window(context.lds_read_window, tile_distribution)); + } +}; + +/// @brief Apply elementwise operation with auxiliary tensors +/// +/// @tparam Elementwise Elementwise functor type +/// @tparam NumAux Number of auxiliary tensors to load and apply +/// +/// @par Context Requirements +/// out_tile: In/out tile for elementwise operation +/// aux_windows: Tuple of auxiliary tensor windows +template +struct ElementwiseOp +{ + template + CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window, + [[maybe_unused]] const AccTile& acc_tile, + [[maybe_unused]] const AuxWindows& aux_windows, + [[maybe_unused]] void* p_smem, + [[maybe_unused]] IAccess iAccess, + Context& context) + { + const auto aux_tiles = generate_tuple( + [&](auto idx) { return load_tile(context.aux_windows[idx]); }, 
number{}); + + const auto tiles = concat_tuple_of_reference( + tie(context.out_tile, context.out_tile), + generate_tie([&](auto idx) -> const auto& { return aux_tiles[idx]; }, + number{})); + + tile_elementwise_inout_unpack(Elementwise{}, tiles); + } +}; + +/// @brief Store output tile to global memory +/// +/// @tparam MemOp Memory operation type (set or atomic_add) +/// +/// @par Context Requirements +/// out_tile: Tile to store +template +struct StoreOp +{ + template + CK_TILE_DEVICE void operator()(OutWindow& out_window, + [[maybe_unused]] const AccTile& acc_tile, + [[maybe_unused]] const AuxWindows& aux_windows, + [[maybe_unused]] void* p_smem, + [[maybe_unused]] IAccess iAccess, + Context& context) + { + if constexpr(MemOp == memory_operation_enum::set) + { + store_tile(out_window, context.out_tile); + } + else + { + update_tile(out_window, context.out_tile); + } + } +}; + +/// @brief Move output and auxiliary windows by step from space-filling curve +/// +/// @tparam SFC Space filling curve type providing step computation +/// @tparam NumAux Number of auxiliary windows to move +/// +/// @par Context Requirements +/// aux_windows: Tuple of windows to move +template +struct MoveWindowsOp +{ + template + CK_TILE_DEVICE void operator()(OutWindow& out_window, + [[maybe_unused]] const AccTile& acc_tile, + [[maybe_unused]] const AuxWindows& aux_windows, + [[maybe_unused]] void* p_smem, + IAccess iAccess, + Context& context) + { + constexpr index_t num_access = SFC::get_num_of_access(); + if constexpr(iAccess != num_access - 1) + { + constexpr auto step = SFC::get_forward_step(iAccess); + + move_tile_window(out_window, {step.at(number<0>{}), step.at(number<1>{})}); + + static_for<0, NumAux, 1>{}([&](auto idx) { + move_tile_window(context.aux_windows[idx], + {step.at(number<0>{}), step.at(number<1>{})}); + }); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp 
b/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp new file mode 100644 index 0000000000..e8bd8c0c7d --- /dev/null +++ b/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp @@ -0,0 +1,512 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +#include + +namespace ck_tile { + +//------------------------------------------------------------------------------ +// CShuffle-specific epilogue operations +// These operations are specific to CShuffle epilogue due to its unique. +//------------------------------------------------------------------------------ + +/// @brief Slice accumulator tile for CShuffle epilogue +/// +/// @par Purpose +/// Extracts a portion of the accumulator tile into the working tile +/// based on the current iteration index. This is CShuffle-specific. 
+/// +/// @tparam SFC Space filling curve type +/// @tparam CWarpDstr Warp distribution type +/// @tparam NumMXdlPerWavePerShuffle XDL tiles in M per wave per shuffle +/// @tparam NumNXdlPerWavePerShuffle XDL tiles in N per wave per shuffle +/// @tparam MPerIterShuffle M elements per shuffle iteration +/// @tparam NPerIterShuffle N elements per shuffle iteration +template +struct CShuffleSliceOp +{ + template + CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window, + const AccTile& acc_tile, + [[maybe_unused]] const AuxWindows& aux_windows, + [[maybe_unused]] void* p_smem, + IAccess iAccess, + Context& context) + { + constexpr auto idx_start = SFC::get_index(iAccess); + constexpr auto m_iter = number{}) / MPerIterShuffle>{}; + constexpr auto n_iter = number{}) / NPerIterShuffle>{}; + + constexpr auto warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto warp_y_index_zeros = uniform_sequence_gen_t{}; + + context.working_tile.get_thread_buffer() = acc_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + warp_y_index_zeros), + merge_sequences(sequence{}, + warp_y_lengths)); + } +}; + +/// @brief Scale working tile using tensor windows (CShuffle-specific) +/// +/// @par Purpose +/// Scales the working tile using row and column scale tensors. +/// CShuffle-specific because it creates scale windows from the +/// working tile's distribution and handles window movement. 
+/// +/// @tparam SFC Space filling curve type +template +struct CShuffleScaleWindowOp +{ + template + CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window, + [[maybe_unused]] const AccTile& acc_tile, + [[maybe_unused]] const AuxWindows& aux_windows, + [[maybe_unused]] void* p_smem, + IAccess iAccess, + Context& context, + const ScaleRowTensor& scale_row_tensor, + const ScaleColTensor& scale_col_tensor) + { + auto scale_row_window = + make_tile_window(scale_row_tensor, context.working_tile.get_tile_distribution()); + auto scale_col_window = + make_tile_window(scale_col_tensor, context.working_tile.get_tile_distribution()); + + const auto scale_row_tile = load_tile(scale_row_window); + const auto scale_col_tile = load_tile(scale_col_window); + + tile_elementwise_inout(element_wise::MultiDMultiply{}, + context.working_tile, + context.working_tile, + scale_row_tile, + scale_col_tile); + + constexpr index_t num_access = SFC::get_num_of_access(); + if constexpr(iAccess != num_access - 1) + { + constexpr auto step = SFC::get_forward_step(number{}); + move_tile_window(scale_row_window, {step.at(number<0>{}), step.at(number<1>{})}); + move_tile_window(scale_col_window, {step.at(number<0>{}), step.at(number<1>{})}); + } + } +}; + +//------------------------------------------------------------------------------ +// CShuffle problem and base operation definitions +//------------------------------------------------------------------------------ + +/// @brief Problem configuration for CShuffle epilogue chainer operations +/// @note Mirrors CShuffleEpilogueProblem but uses AsDataType/BsDataType for tuple support +template +struct CShuffleEpilogueChainProblem +{ + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static 
constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; + static constexpr index_t MWave = MWave_; + static constexpr index_t NWave = NWave_; + static constexpr index_t MPerXdl = MPerXdl_; + static constexpr index_t NPerXdl = NPerXdl_; + static constexpr index_t KPerXdl = KPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; + static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; + static constexpr bool FixedVectorSize = FixedVectorSize_; + static constexpr index_t VectorSizeC = VectorSizeC_; + static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; + static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; + static constexpr index_t kNumWaveGroups = kNumWaveGroups_; + static constexpr index_t NumDTensor = DsDataType::size(); + + static_assert(NumDTensor == DsLayout::size(), + "The size of DsDataType and DsLayout should be the same"); +}; + +template +struct CShuffleEpilogueChainBaseOp +{ + using Problem = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; + + using ATypeToUse = + std::conditional_t, BDataType, ADataType>; + // Used for weight-only quantization kernel, B would be dequantized to the same data type as A + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + 
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kMPerBlock = Problem::kMPerBlock; + static constexpr index_t kNPerBlock = Problem::kNPerBlock; + static constexpr index_t MWave = Problem::MWave; + static constexpr index_t NWave = Problem::NWave; + static constexpr index_t MPerXdl = Problem::MPerXdl; + static constexpr index_t NPerXdl = Problem::NPerXdl; + static constexpr index_t KPerXdl = Problem::KPerXdl; + static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr bool FixedVectorSize = Problem::FixedVectorSize; + static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; + static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; + static constexpr index_t MPerIteration = MPerXdl * MWave; + static constexpr index_t NPerIteration = NPerXdl * NWave; + static constexpr index_t NumDTensor = Problem::NumDTensor; + + static_assert(NumDTensor == DsLayout::size(), + "The size of DsDataType and DsLayout should be the same"); + + /** + * @brief Get the vector store size for C tensor. + * + * @note The vector store size for output C tensor would depend on multiple factors + * like its data layout and warp gemm C transposition. In general it would + * be the number of consecutive elements in contiguous C dimension hold by + * single thread. + * + * @return The vector store size for C tensor. 
+ */ + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC() + { + if constexpr(FixedVectorSize) + { + return VectorSizeC; + } + constexpr index_t max_vector_size = 16; + if constexpr(std::is_same_v) + { + return std::min(static_cast(NPerIteration), + static_cast(max_vector_size / sizeof(ODataType))); + } + else if constexpr(std::is_same_v) + { + return std::min(static_cast(MPerIteration), + static_cast(max_vector_size / sizeof(ODataType))); + } + else + { + static_assert(false, "Unsupported ELayout!"); + } + } + + /** + * @brief Get the vector store size for Di tensor. + * + * @return The vector store size for Di tensor. + */ + template + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number index) + { + constexpr index_t max_vector_size = 16; + using DiDataType = remove_cvref_t>; + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return std::min(static_cast(NPerIteration), + static_cast(max_vector_size / sizeof(DiDataType))); + } + else if constexpr(std::is_same_v) + { + return std::min(static_cast(MPerIteration), + static_cast(max_vector_size / sizeof(DiDataType))); + } + else + { + static_assert(false, "Unsupported DLayout!"); + } + } + /** + * @brief Shuffle tile configuration parameters + * + * @details These parameters control the number of XDL tiles processed per wave in each shuffle + * iteration: + * - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave + * - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave + */ + static constexpr auto shuffle_tile_tuple = [] { + constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size(); + if constexpr(elem_per_thread >= GetVectorSizeC()) + { + return std::make_tuple(1, 1); + } + else + { + constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread; + if constexpr(std::is_same_v) + { + static_assert((kMPerBlock % (MPerXdl * MWave) == 0) && + (kMPerBlock % num_xdl_shuffles == 0), + "kMPerBlock 
must be divisible by MPerXdl*MWave and " + "num_xdl_shuffles for CShuffleEpilogueStageBase"); + return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1); + } + else + { + static_assert((kNPerBlock % (NPerXdl * NWave) == 0) && + (kNPerBlock % num_xdl_shuffles == 0), + "kNPerBlock must be divisible by NPerXdl*NWave and " + "num_xdl_shuffles for CShuffleEpilogueStageBase"); + return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave))); + } + } + }(); + static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple); + static constexpr index_t NumNXdlPerWavePerShuffle = + max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple)); + + static constexpr auto MNPerIterationShuffle = [] { + constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle; + constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle; + if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0) + return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave); + else + return std::make_tuple(m_val, n_val); + }(); + static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle); + static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle); + + using WG = WarpGemmDispatcher; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + using CWarpDstrEncoding = typename WG::CWarpDstrEncoding; + using SFC = space_filling_curve, + sequence<0, 1>, + sequence>; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor() + { + // N is contiguous dimension + if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{})); + } + // M is contiguous dimension + else if constexpr(std::is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(number{}, number{}), + make_tuple(number<1>{}, number{})); + } + else + { + static_assert(false, 
"Unsupported ELayout!"); + } + } + + CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode() + { + constexpr auto block_outer_dstr_encoding = [] { + if constexpr(BlockedXDLN_PerWarp == 1) + { + return tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + } + else + { + constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp; + // BlockedLayout + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2, 2>, + sequence<0, 0, 2>>{}; + } + }(); + constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding( + block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{}); + + return block_dstr_encoding; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType); + } + + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + + /// @brief Context structure for CShuffle epilogue operations + /// + /// @par Purpose + /// The context serves as a shared workspace that maintains intermediate results + /// and resources across multiple epilogue operations. It eliminates the need for + /// operations to recreate shared data structures and enables data flow + /// through the operation graph. 
+ /// + /// @par Standardized Interface + /// Uses standardized member names so common operations can work with this context: + /// - working_tile: Intermediate tile for shuffle operations + /// - out_tile: Output tile for final results + /// - aux_windows: Auxiliary tensor windows (D tensors) + /// - lds_write_window: Window for writing to LDS + /// - lds_read_window: Window for reading from LDS + template + struct CShuffleContext + { + WorkingTileType working_tile; // Working tile for shuffle operations + LdsBlockType lds_block; // LDS block view + LdsWriteWindowType lds_write_window; // Window for writing to LDS + LdsReadWindowType lds_read_window; // Window for reading from LDS + AuxWindowsType aux_windows; // Auxiliary tensor windows (D tensors) + OutTileType out_tile; // Output tile + }; + + template + CK_TILE_DEVICE auto operator()([[maybe_unused]] OutDramWindow& out_window, + [[maybe_unused]] const AccTile& acc_tile, + const DsDramWindows& ds_windows, + void* p_smem) + { + static_assert( + std::is_same_v, + "Currently, the CShuffleEpilogueStageBase only supports the Row Major Output layout"); + + constexpr auto working_tile_distr = + make_static_tile_distribution(MakeLdsDistributionEncode()); + auto working_tile = make_static_distributed_tensor(working_tile_distr); + + constexpr auto lds_block_desc = MakeLdsBlockDescriptor(); + auto lds_block = make_tensor_view(static_cast(p_smem), + lds_block_desc); + + auto lds_write_window = make_tile_window( + lds_block, + make_tuple(number{}, number{}), + {0, 0}, + working_tile_distr); + + auto lds_read_window = make_tile_window( + lds_block, + make_tuple(number{}, number{}), + {0, 0}); + + constexpr auto dram_tile_distribution = + TileEncodingPattern::make_2d_static_tile_distribution(); + auto aux_windows = generate_tuple( + [&](auto idx) { return make_tile_window(ds_windows[idx], dram_tile_distribution); }, + number{}); + + auto out_tile = load_tile(make_tile_window(lds_read_window, dram_tile_distribution)); + + 
using ContextType = CShuffleContext; + + ContextType context; + context.working_tile = working_tile; + context.lds_block = lds_block; + context.lds_write_window = lds_write_window; + context.lds_read_window = lds_read_window; + context.aux_windows = aux_windows; + context.out_tile = out_tile; + + return context; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_schedule.hpp b/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_schedule.hpp new file mode 100644 index 0000000000..683bfe7377 --- /dev/null +++ b/include/ck_tile/ops/epilogue/chainer/cshuffle_epilogue_schedule.hpp @@ -0,0 +1,129 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp" +#include "ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp" +#include "ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp" + +namespace ck_tile { + +/// @brief Schedule type tags for epilogue selection +/// @par Purpose +/// Each tag corresponds to a pre-built schedule, these are used to select a schedule + +/// Standard epilogue schedule: Slice → CastStore → Load → ApplyD → Store → Move +struct DefaultScheduleTag +{ +}; + +/// RowCol quantization schedule: Slice → ScaleWindow → CastStore → Load → ApplyD → Store → Move +struct RowColQuantScheduleTag +{ +}; + +/// Tensor quantization schedule: Slice → ScaleScalar → CastStore → Load → ApplyD → Store → Move +struct TensorQuantScheduleTag +{ +}; + +/// @brief CShuffle epilogue scheduler providing pre-built schedules +/// +/// @par Overview +/// CshuffleEpilogueSchedule acts as the scheduler component for EpilogueChainer. +/// It provides context creation and pre-built schedules. The scheduler +/// uses tags to select/create appropriate epilogue schedule. 
+/// +/// @tparam Problem The epilogue problem configuration +/// @tparam ScheduleTag Tag selecting the epilogue schedule +template +struct CshuffleEpilogueSchedule +{ + using ProblemType = Problem; + using BaseOp = CShuffleEpilogueChainBaseOp; + + static constexpr index_t NumAccess = BaseOp::SFC::get_num_of_access(); + + /// @brief Create context for epilogue operations + template + CK_TILE_DEVICE static auto create_context(OutWindow& out_window, + const AccTile& acc_tile, + const AuxWindows& aux_windows, + void* p_smem) + { + return BaseOp{}(out_window, acc_tile, aux_windows, p_smem); + } + + /// @brief Make schedule based on compile-time tag selection + template + CK_TILE_DEVICE static auto make_schedule(Args&&... args) + { + if constexpr(std::is_same_v) + { + // Standard epilogue + // Schedule: Slice -> CastAndStoreLds -> Load -> ApplyD -> Store -> MoveWindows + static_assert(sizeof...(args) == 0, "DefaultSchedule expects no arguments"); + return make_graph( + make_node>(), + make_node>(), + make_node>(), + make_node>(), + make_node>(), + make_node>()); + } + else if constexpr(std::is_same_v) + { + // RowCol quantization schedule with tensor windows + // Schedule: Slice -> ScaleWindow -> CastAndStoreLds -> Load -> ApplyD -> Store -> + // MoveWindows + static_assert(sizeof...(args) == 2, + "RowColQuantSchedule requires exactly 2 scale tensor arguments"); + return make_graph( + make_node>(), + make_node>(std::forward(args)...), + make_node>(), + make_node>(), + make_node>(), + make_node>(), + make_node>()); + } + else if constexpr(std::is_same_v) + { + // Tensor quantization schedule with scalar values + // Schedule: Slice -> ScaleScalar -> CastAndStoreLds -> Load -> ApplyD -> Store -> + // MoveWindows + static_assert(sizeof...(args) == 2, + "TensorQuantSchedule requires exactly 2 scalar arguments"); + return make_graph( + make_node>(), + make_node(std::forward(args)...), + make_node>(), + make_node>(), + make_node>(), + make_node>(), + make_node>()); + } + else 
+ { + static_assert(sizeof(ScheduleTag) == 0, "Unknown schedule tag"); + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp b/include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp new file mode 100644 index 0000000000..25ef000cc3 --- /dev/null +++ b/include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp @@ -0,0 +1,213 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" + +namespace ck_tile { + +/// @brief Epilogue Chainer - Modular epilogue processing facilitator +/// +/// @par Overview +/// EpilogueChainer provides an interface for processing epilogue operations +/// through schedules. The chainer uses decomposed epilogue operations, these are +/// scheduled/sequenced by a Scheduler to form operation graphs. +/// +/// @tparam Scheduler The schedule provider that defines epilogue operation graphs +template +struct EpilogueChainer +{ + using Problem = typename Scheduler::ProblemType; + using BaseOp = typename Scheduler::BaseOp; + + using ODataType = typename BaseOp::ODataType; + using DsDataType = typename BaseOp::DsDataType; + using DsLayout = typename BaseOp::DsLayout; + using AccDataType = typename BaseOp::AccDataType; + static constexpr auto MemoryOperation = BaseOp::MemoryOperation; + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return BaseOp::GetSmemSize(); } + + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC() + { + return BaseOp::GetVectorSizeC(); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number idx) + { + return BaseOp::GetVectorSizeD(idx); + } + + CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode() + { + return BaseOp::MakeLdsDistributionEncode(); + } + + /// @brief Process epilogue through scheduler-defined operation graph + /// + /// @par Flow + /// 1.
Create shared context through scheduler + /// 2. Generate operation schedule based on arguments + /// 3. Run scheduled operations in sequence + template + CK_TILE_DEVICE void operator()(OutWindow& out_window, + const AccTile& acc_tile, + const AuxWindows& aux_windows, + void* p_smem, + Args&&... args) const + { + // The context serves as a shared workspace that maintains intermediate results + // and resources across multiple epilogue operations. + auto context = Scheduler::create_context(out_window, acc_tile, aux_windows, p_smem); + auto schedule = Scheduler::make_schedule(std::forward(args)...); + schedule(out_window, acc_tile, aux_windows, p_smem, context); + } +}; + +/// @brief Epilogue operation wrapper with arguments +/// +/// @par Purpose +/// EpilogueNode wraps individual epilogue operations with their required arguments, +/// allowing them to be composed into operation graphs. Arguments are captured at construction +/// time and automatically forwarded during processing. +/// +/// @tparam EpilogueType Epilogue operation (e.g., SliceEpilogue, ScaleEpilogue) +/// @tparam Args Types of arguments required by the epilogue operation +template +struct EpilogueNode +{ + using Epilogue = EpilogueType; + ck_tile::tuple args; + + constexpr EpilogueNode(Args... a) : args(a...) {} + + /// @brief Process epilogue without iteration index + template + CK_TILE_DEVICE void operator()(OutWindow& out_window, + const AccTile& acc_tile, + const AuxWindows& aux_windows, + void* p_smem, + Context& context) const + { + ck_tile::apply( + [&](auto&&... epilogue_args) { + EpilogueType{}(out_window, + acc_tile, + aux_windows, + p_smem, + context, + std::forward(epilogue_args)...); + }, + args); + } + + /// @brief Process epilogue with iteration index + template + CK_TILE_DEVICE void operator()(OutWindow& out_window, + const AccTile& acc_tile, + const AuxWindows& aux_windows, + void* p_smem, + Context& context, + number iAccess) const + { + ck_tile::apply( + [&](auto&&... 
epilogue_args) { + EpilogueType{}(out_window, + acc_tile, + aux_windows, + p_smem, + iAccess, + context, + std::forward(epilogue_args)...); + }, + args); + } +}; + +/// @brief Specialization for epilogue operation wrapper with no arguments +template +struct EpilogueNode +{ + using Epilogue = EpilogueType; + ck_tile::tuple<> args; + + constexpr EpilogueNode() = default; + + template + CK_TILE_DEVICE void operator()(OutWindow& out_window, + const AccTile& acc_tile, + const AuxWindows& aux_windows, + void* p_smem, + Context& context) const + { + EpilogueType{}(out_window, acc_tile, aux_windows, p_smem, context); + } + + template + CK_TILE_DEVICE void operator()(OutWindow& out_window, + const AccTile& acc_tile, + const AuxWindows& aux_windows, + void* p_smem, + Context& context, + number iAccess) const + { + EpilogueType{}(out_window, acc_tile, aux_windows, p_smem, iAccess, context); + } +}; + +/// @brief Operation graph that sequentially composes multiple epilogue operations +/// +/// @tparam Steps Number of iterations +/// @tparam EpilogueTypes Types of epilogue nodes in the operation graph +template +struct EpilogueGraph +{ + ck_tile::tuple epilogues; + + constexpr EpilogueGraph() = default; + constexpr EpilogueGraph(EpilogueTypes... eps) : epilogues(eps...) {} + + /// @brief Process all epilogues for each iteration in sequence + template + CK_TILE_DEVICE void operator()(OutWindow& out_window, + const AccTile& acc_tile, + const AuxWindows& aux_windows, + void* p_smem, + Context& context) const + { + // For each iteration, process all epilogues in order + static_for<0, Steps, 1>{}([&](auto iAccess) { + static_for<0, sizeof...(EpilogueTypes), 1>{}([&](auto I) { + epilogues.template get()( + out_window, acc_tile, aux_windows, p_smem, context, iAccess); + }); + }); + } +}; + +/// @brief Helper function for creating epilogue nodes +template +constexpr auto make_node(Args... 
args) +{ + return EpilogueNode<EpilogueType, Args...>{args...}; +} + +/// @brief Helper function for creating operation graphs +template <index_t Steps, typename... EpilogueTypes> +constexpr auto make_graph(EpilogueTypes... epilogues) +{ + return EpilogueGraph<Steps, EpilogueTypes...>{epilogues...}; +} + +} // namespace ck_tile