[CK_TILE] Epilogue chaining (Lwpck 3373) (#2773)

* Epilogue chainer

* epilogue chainer with context to share state in between epilogues
* chain-able epilogues for cshuffle

* clang-format

* rebase-related changes

- Added separate chainer test
- clang-format

* comment resolutions

* clang-format

* Policy based chaining

- basic Policy structure to control blanket looping and barrier
placement.

- to be extended for fine-grained control

- to be modified to move possible auto-computed values and SFC access
count to policy

* Refactoring as per spec

- Introduced epilogue schedule, graph
- modified chainer to function with graph and schedule

* minor changes

- turned functions in the epilogue_graph file into overloads

* clang-format

* Documentation and Comments

- Added comments to files
- Noted changes in changelog
- Added README to explain the chainer and current status; exact usage
steps to be added

* Comment resolutions

- README modified with the suggested changes
- Comments fixed accordingly

* major refactoring

- modified the chainer files to match the new design
- updated comments
- updated readme
- multi-D example showcases use of the chainer

* minor cleanup

* tensor and rowcol quant chainer epilogue

- added scalar epilogue for tensor quant
- added schedule for tensor quant
- modified quant example to use chainer and appropriate schedules

* Refactor epilogue chainer: generalize ops and standardize context interface

Address review comments.

Changes:
- Rename CastToLdsOp to CastAndStoreToLdsOp for clarity
- Standardize context member names (working_tile, out_tile, aux_windows)
- Update README documentation with correct operation names
- Clean up parameter naming in epilogue_chainer.hpp (OutWindow, AccTile,
AuxWindows)
- common_epilogue_ops.hpp: General-purpose ops (ScaleScalarOp,
CastAndStoreToLdsOp,
  LoadFromLdsOp, ElementwiseOp, StoreOp, MoveWindowsOp)
- cshuffle_epilogue_chainer_ops.hpp: CShuffle-specific context and slice
operations
- epilogue_chainer.hpp: Cleaned up parameter naming for generality
- Removed test files that are no longer needed. These were added for
intermediate use

* update cshuffle chainer ops file w.r.t. cshuffle_epilogue.hpp updates & add chainer to quant gemm example

* fix compile errors

- CI uses C++17 while the code had C++20 features

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
Yashvardhan Agarwal
2025-12-18 11:02:02 +02:00
committed by GitHub
parent bfac64953f
commit 15e81397a4
9 changed files with 1244 additions and 42 deletions


@@ -2,6 +2,10 @@
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp"
#include "ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp"
#include "ck_tile/ops/epilogue/chainer/cshuffle_epilogue_schedule.hpp"
#include "ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"


@@ -0,0 +1,61 @@
# CK Tile Epilogue Chainer
## Overview
The Epilogue Chainer provides a modular epilogue processing framework through scheduler-defined operation graphs.
## Architecture
### Core Design Principle
The chainer follows a **Scheduler-Graph-Node** architecture with shared context:
- **Scheduler**: Defines operation graphs and creates a shared context
- **Graph**: Composes multiple operations into sequential processing units
- **Node**: Wraps individual epilogue operations with their arguments
### EpilogueChainer
The `EpilogueChainer` struct serves as the modular epilogue processing facilitator. It delegates to schedulers for context creation and schedule generation, then processes the resulting operation graphs.
### EpilogueNode
Individual epilogue operations are wrapped in `EpilogueNode` structures that capture required arguments at construction time and automatically forward them during processing. Supports both parameterized and parameter-free operations.
### EpilogueGraph
The `EpilogueGraph` composes multiple nodes into sequential processing units that iterate over multiple accesses if needed, running all operations in order for each iteration.
## Files
### Core Infrastructure
- `epilogue_chainer.hpp` - General chainer, node, and graph infrastructure
- `common_epilogue_ops.hpp` - Epilogue operations usable with any epilogue type
### CShuffle Implementation
- `cshuffle_epilogue_chainer_ops.hpp` - CShuffle-specific problem, context, and slice operations
- `cshuffle_epilogue_schedule.hpp` - CShuffle scheduler with pre-built schedules
## Usage
### Common Operations (common_epilogue_ops.hpp)
These operations work with any context that provides the standardized interface:
- `ScaleScalarOp` - Scale working-tile by scalar values
- `CastAndStoreToLdsOp<DstType>` - Cast working-tile and store to LDS
- `LoadFromLdsOp<Pattern>` - Load output tile from LDS with sync
- `ElementwiseOp<Func, NumAux>` - Apply elementwise operation with auxiliary tensors
- `StoreOp<MemOp>` - Store output tile to global memory
- `MoveWindowsOp<SFC, NumAux>` - Advance windows to next position
### CShuffle-Specific Operations (cshuffle_epilogue_chainer_ops.hpp)
These operations are specific to CShuffle epilogue:
- `CShuffleSliceOp` - Slice accumulator tile based on distribution
- `CShuffleScaleWindowOp` - Scale using tensor windows with shuffle distribution
### Context Interface
Operations communicate through a shared context with standardized members:
- `working_tile`: Tile for intermediate computations
- `out_tile`: Output tile
- `aux_windows`: Tuple of auxiliary tensor windows
- `lds_write_window`: Window for writing to LDS
- `lds_read_window`: Window for reading from LDS
### Schedule Tags
- `DefaultScheduleTag` - Standard: Slice → CastStore → Load → ApplyD → Store → Move
- `RowColQuantScheduleTag` - With window scaling
- `TensorQuantScheduleTag` - With scalar scaling


@@ -0,0 +1,208 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
/// @file common_epilogue_ops.hpp
/// @brief Reusable simple epilogue operations which can be composed into more complex ones.
///
///
/// @par Context Interface
/// Operations expect the context to provide:
/// - working_tile: Tile for intermediate computations
/// - out_tile: Output tile for final results
/// - aux_windows: Tuple of auxiliary tensor windows (e.g., D tensors)
/// - lds_write_window: Window for writing to LDS (if using LDS)
/// - lds_read_window: Window for reading from LDS (if using LDS)
namespace ck_tile {
/// @brief Scale working tile by scalar values
///
/// @par Context Requirements
/// working_tile: Tile to scale
struct ScaleScalarOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context,
typename ScaleA,
typename ScaleB>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
[[maybe_unused]] IAccess iAccess,
Context& context,
const ScaleA& scale_a,
const ScaleB& scale_b)
{
tile_elementwise_inout([&](auto& elem) { elem = elem * scale_a * scale_b; },
context.working_tile);
}
};
/// @brief Cast working tile and store to LDS
///
/// @tparam DataType Target data type for casting
///
/// @par Context Requirements
/// working_tile: Tile to cast
/// lds_write_window: Window for writing to LDS
template <typename DataType>
struct CastAndStoreToLdsOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
[[maybe_unused]] IAccess iAccess,
Context& context)
{
const auto casted_tile = cast_tile<DataType>(context.working_tile);
store_tile(context.lds_write_window, casted_tile);
}
};
/// @brief Load output tile from LDS with synchronization
///
/// @tparam TileEncodingPattern Pattern for tile distribution
///
/// @par Context Requirements
/// lds_read_window: Window for reading from LDS
/// out_tile: Destination for loaded tile
template <typename TileEncodingPattern>
struct LoadFromLdsOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
[[maybe_unused]] IAccess iAccess,
Context& context)
{
constexpr auto tile_distribution = TileEncodingPattern::make_2d_static_tile_distribution();
block_sync_lds();
context.out_tile = load_tile(make_tile_window(context.lds_read_window, tile_distribution));
}
};
/// @brief Apply elementwise operation with auxiliary tensors
///
/// @tparam Elementwise Elementwise functor type
/// @tparam NumAux Number of auxiliary tensors to load and apply
///
/// @par Context Requirements
/// out_tile: In/out tile for elementwise operation
/// aux_windows: Tuple of auxiliary tensor windows
template <typename Elementwise, index_t NumAux>
struct ElementwiseOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
[[maybe_unused]] IAccess iAccess,
Context& context)
{
const auto aux_tiles = generate_tuple(
[&](auto idx) { return load_tile(context.aux_windows[idx]); }, number<NumAux>{});
const auto tiles = concat_tuple_of_reference(
tie(context.out_tile, context.out_tile),
generate_tie([&](auto idx) -> const auto& { return aux_tiles[idx]; },
number<NumAux>{}));
tile_elementwise_inout_unpack(Elementwise{}, tiles);
}
};
/// @brief Store output tile to global memory
///
/// @tparam MemOp Memory operation type (set or atomic_add)
///
/// @par Context Requirements
/// out_tile: Tile to store
template <memory_operation_enum MemOp>
struct StoreOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
[[maybe_unused]] IAccess iAccess,
Context& context)
{
if constexpr(MemOp == memory_operation_enum::set)
{
store_tile(out_window, context.out_tile);
}
else
{
update_tile(out_window, context.out_tile);
}
}
};
/// @brief Move output and auxiliary windows by step from space-filling curve
///
/// @tparam SFC Space filling curve type providing step computation
/// @tparam NumAux Number of auxiliary windows to move
///
/// @par Context Requirements
/// aux_windows: Tuple of windows to move
template <typename SFC, index_t NumAux>
struct MoveWindowsOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
IAccess iAccess,
Context& context)
{
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_window, {step.at(number<0>{}), step.at(number<1>{})});
static_for<0, NumAux, 1>{}([&](auto idx) {
move_tile_window(context.aux_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
});
}
}
};
} // namespace ck_tile


@@ -0,0 +1,512 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include <optional>
namespace ck_tile {
//------------------------------------------------------------------------------
// CShuffle-specific epilogue operations
// These operations are specific to the CShuffle epilogue due to its unique tile-shuffle data path.
//------------------------------------------------------------------------------
/// @brief Slice accumulator tile for CShuffle epilogue
///
/// @par Purpose
/// Extracts a portion of the accumulator tile into the working tile
/// based on the current iteration index. This is CShuffle-specific.
///
/// @tparam SFC Space filling curve type
/// @tparam CWarpDstr Warp distribution type
/// @tparam NumMXdlPerWavePerShuffle XDL tiles in M per wave per shuffle
/// @tparam NumNXdlPerWavePerShuffle XDL tiles in N per wave per shuffle
/// @tparam MPerIterShuffle M elements per shuffle iteration
/// @tparam NPerIterShuffle N elements per shuffle iteration
template <typename SFC,
typename CWarpDstr,
index_t NumMXdlPerWavePerShuffle,
index_t NumNXdlPerWavePerShuffle,
index_t MPerIterShuffle,
index_t NPerIterShuffle>
struct CShuffleSliceOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
IAccess iAccess,
Context& context)
{
constexpr auto idx_start = SFC::get_index(iAccess);
constexpr auto m_iter = number<idx_start.at(number<0>{}) / MPerIterShuffle>{};
constexpr auto n_iter = number<idx_start.at(number<1>{}) / NPerIterShuffle>{};
constexpr auto warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
context.working_tile.get_thread_buffer() = acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<m_iter * NumMXdlPerWavePerShuffle, n_iter * NumNXdlPerWavePerShuffle>{},
warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
warp_y_lengths));
}
};
/// @brief Scale working tile using tensor windows (CShuffle-specific)
///
/// @par Purpose
/// Scales the working tile using row and column scale tensors.
/// CShuffle-specific because it creates scale windows from the
/// working tile's distribution and handles window movement.
///
/// @tparam SFC Space filling curve type
template <typename SFC>
struct CShuffleScaleWindowOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context,
typename ScaleRowTensor,
typename ScaleColTensor>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
IAccess iAccess,
Context& context,
const ScaleRowTensor& scale_row_tensor,
const ScaleColTensor& scale_col_tensor)
{
auto scale_row_window =
make_tile_window(scale_row_tensor, context.working_tile.get_tile_distribution());
auto scale_col_window =
make_tile_window(scale_col_tensor, context.working_tile.get_tile_distribution());
const auto scale_row_tile = load_tile(scale_row_window);
const auto scale_col_tile = load_tile(scale_col_window);
tile_elementwise_inout(element_wise::MultiDMultiply{},
context.working_tile,
context.working_tile,
scale_row_tile,
scale_col_tile);
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(number<iAccess>{});
move_tile_window(scale_row_window, {step.at(number<0>{}), step.at(number<1>{})});
move_tile_window(scale_col_window, {step.at(number<0>{}), step.at(number<1>{})});
}
}
};
//------------------------------------------------------------------------------
// CShuffle problem and base operation definitions
//------------------------------------------------------------------------------
/// @brief Problem configuration for CShuffle epilogue chainer operations
/// @note Mirrors CShuffleEpilogueProblem but uses AsDataType/BsDataType for tuple support
template <typename AsDataType_,
typename BsDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
typename DsLayout_,
typename ELayout_,
typename CDElementwise_,
index_t kM_,
index_t kN_,
index_t MWave_,
index_t NWave_,
index_t MPerXdl_,
index_t NPerXdl_,
index_t KPerXdl_,
bool isCTransposed_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1>
struct CShuffleEpilogueChainProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
};
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogueChainBaseOp
{
using Problem = remove_cvref_t<Problem_>;
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t MWave = Problem::MWave;
static constexpr index_t NWave = Problem::NWave;
static constexpr index_t MPerXdl = Problem::MPerXdl;
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
/**
* @brief Get the vector store size for C tensor.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
{
if constexpr(FixedVectorSize)
{
return VectorSizeC;
}
constexpr index_t max_vector_size = 16;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
/**
* @brief Get the vector store size for Di tensor.
*
* @return The vector store size for Di tensor.
*/
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
{
constexpr index_t max_vector_size = 16;
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else
{
static_assert(false, "Unsupported DLayout!");
}
}
/**
* @brief Shuffle tile configuration parameters
*
* @details These parameters control the number of XDL tiles processed per wave in each shuffle
* iteration:
* - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
*/
static constexpr auto shuffle_tile_tuple = [] {
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
if constexpr(elem_per_thread >= GetVectorSizeC())
{
return std::make_tuple(1, 1);
}
else
{
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by MPerXdl*MWave and "
"num_xdl_shuffles for CShuffleEpilogueStageBase");
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
}
else
{
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by NPerXdl*NWave and "
"num_xdl_shuffles for CShuffleEpilogueStageBase");
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
}
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle =
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
static constexpr auto MNPerIterationShuffle = [] {
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
else
return std::make_tuple(m_val, n_val);
}();
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
using WG = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
AccDataType,
MPerXdl,
NPerXdl,
KPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
// N is contiguous dimension
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
}
// M is contiguous dimension
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding = [] {
if constexpr(BlockedXDLN_PerWarp == 1)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else
{
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
return block_dstr_encoding;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
}
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
/// @brief Context structure for CShuffle epilogue operations
///
/// @par Purpose
/// The context serves as a shared workspace that maintains intermediate results
/// and resources across multiple epilogue operations. It eliminates the need for
/// operations to recreate shared data structures and enables data flow
/// through the operation graph.
///
/// @par Standardized Interface
/// Uses standardized member names so common operations can work with this context:
/// - working_tile: Intermediate tile for shuffle operations
/// - out_tile: Output tile for final results
/// - aux_windows: Auxiliary tensor windows (D tensors)
/// - lds_write_window: Window for writing to LDS
/// - lds_read_window: Window for reading from LDS
template <typename WorkingTileType,
typename LdsBlockType,
typename LdsWriteWindowType,
typename LdsReadWindowType,
typename AuxWindowsType,
typename OutTileType>
struct CShuffleContext
{
WorkingTileType working_tile; // Working tile for shuffle operations
LdsBlockType lds_block; // LDS block view
LdsWriteWindowType lds_write_window; // Window for writing to LDS
LdsReadWindowType lds_read_window; // Window for reading from LDS
AuxWindowsType aux_windows; // Auxiliary tensor windows (D tensors)
OutTileType out_tile; // Output tile
};
template <typename OutDramWindow, typename AccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()([[maybe_unused]] OutDramWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
const DsDramWindows& ds_windows,
void* p_smem)
{
static_assert(
std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffleEpilogueStageBase only supports the Row Major Output layout");
constexpr auto working_tile_distr =
make_static_tile_distribution(MakeLdsDistributionEncode());
auto working_tile = make_static_distributed_tensor<AccDataType>(working_tile_distr);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto lds_block = make_tensor_view<address_space_enum::lds>(static_cast<ODataType*>(p_smem),
lds_block_desc);
auto lds_write_window = make_tile_window(
lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
working_tile_distr);
auto lds_read_window = make_tile_window(
lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
constexpr auto dram_tile_distribution =
TileEncodingPattern::make_2d_static_tile_distribution();
auto aux_windows = generate_tuple(
[&](auto idx) { return make_tile_window(ds_windows[idx], dram_tile_distribution); },
number<NumDTensor>{});
auto out_tile = load_tile(make_tile_window(lds_read_window, dram_tile_distribution));
using ContextType = CShuffleContext<decltype(working_tile),
decltype(lds_block),
decltype(lds_write_window),
decltype(lds_read_window),
decltype(aux_windows),
decltype(out_tile)>;
ContextType context;
context.working_tile = working_tile;
context.lds_block = lds_block;
context.lds_write_window = lds_write_window;
context.lds_read_window = lds_read_window;
context.aux_windows = aux_windows;
context.out_tile = out_tile;
return context;
}
};
} // namespace ck_tile


@@ -0,0 +1,129 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp"
#include "ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp"
#include "ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp"
namespace ck_tile {
/// @brief Schedule type tags for epilogue selection
/// @par Purpose
/// Each tag corresponds to a pre-built schedule; the tag selects which schedule is built.
/// Standard epilogue schedule: Slice → CastStore → Load → ApplyD → Store → Move
struct DefaultScheduleTag
{
};
/// RowCol quantization schedule: Slice → ScaleWindow → CastStore → Load → ApplyD → Store → Move
struct RowColQuantScheduleTag
{
};
/// Tensor quantization schedule: Slice → ScaleScalar → CastStore → Load → ApplyD → Store → Move
struct TensorQuantScheduleTag
{
};
/// @brief CShuffle epilogue scheduler providing pre-built schedules
///
/// @par Overview
/// CshuffleEpilogueSchedule acts as the scheduler component for EpilogueChainer.
/// It provides context creation and pre-built schedules. The scheduler
/// uses tags to select/create the appropriate epilogue schedule.
///
/// @tparam Problem The epilogue problem configuration
/// @tparam ScheduleTag Tag selecting the epilogue schedule
template <typename Problem, typename ScheduleTag = DefaultScheduleTag>
struct CshuffleEpilogueSchedule
{
using ProblemType = Problem;
using BaseOp = CShuffleEpilogueChainBaseOp<Problem>;
static constexpr index_t NumAccess = BaseOp::SFC::get_num_of_access();
/// @brief Create context for epilogue operations
template <typename OutWindow, typename AccTile, typename AuxWindows>
CK_TILE_DEVICE static auto create_context(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem)
{
return BaseOp{}(out_window, acc_tile, aux_windows, p_smem);
}
/// @brief Make schedule based on compile-time tag selection
template <typename... Args>
CK_TILE_DEVICE static auto make_schedule(Args&&... args)
{
if constexpr(std::is_same_v<ScheduleTag, DefaultScheduleTag>)
{
// Standard epilogue
// Schedule: Slice -> CastAndStoreLds -> Load -> ApplyD -> Store -> MoveWindows
static_assert(sizeof...(args) == 0, "DefaultSchedule expects no arguments");
return make_graph<NumAccess>(
make_node<CShuffleSliceOp<typename BaseOp::SFC,
typename BaseOp::CWarpDstr,
BaseOp::NumMXdlPerWavePerShuffle,
BaseOp::NumNXdlPerWavePerShuffle,
BaseOp::MPerIterationShuffle,
BaseOp::NPerIterationShuffle>>(),
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
make_node<StoreOp<Problem::MemoryOperation>>(),
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
}
else if constexpr(std::is_same_v<ScheduleTag, RowColQuantScheduleTag>)
{
// RowCol quantization schedule with tensor windows
// Schedule: Slice -> ScaleWindow -> CastAndStoreLds -> Load -> ApplyD -> Store ->
// MoveWindows
static_assert(sizeof...(args) == 2,
"RowColQuantSchedule requires exactly 2 scale tensor arguments");
return make_graph<NumAccess>(
make_node<CShuffleSliceOp<typename BaseOp::SFC,
typename BaseOp::CWarpDstr,
BaseOp::NumMXdlPerWavePerShuffle,
BaseOp::NumNXdlPerWavePerShuffle,
BaseOp::MPerIterationShuffle,
BaseOp::NPerIterationShuffle>>(),
make_node<CShuffleScaleWindowOp<typename BaseOp::SFC>>(std::forward<Args>(args)...),
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
make_node<StoreOp<Problem::MemoryOperation>>(),
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
}
else if constexpr(std::is_same_v<ScheduleTag, TensorQuantScheduleTag>)
{
// Tensor quantization schedule with scalar values
// Schedule: Slice -> ScaleScalar -> CastAndStoreLds -> Load -> ApplyD -> Store ->
// MoveWindows
static_assert(sizeof...(args) == 2,
"TensorQuantSchedule requires exactly 2 scalar arguments");
return make_graph<NumAccess>(
make_node<CShuffleSliceOp<typename BaseOp::SFC,
typename BaseOp::CWarpDstr,
BaseOp::NumMXdlPerWavePerShuffle,
BaseOp::NumNXdlPerWavePerShuffle,
BaseOp::MPerIterationShuffle,
BaseOp::NPerIterationShuffle>>(),
make_node<ScaleScalarOp>(std::forward<Args>(args)...),
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
make_node<StoreOp<Problem::MemoryOperation>>(),
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
}
else
{
// A plain static_assert(false, ...) is ill-formed even in a discarded
// branch before C++23; make the condition dependent on ScheduleTag so
// it only fires when this branch is actually instantiated.
static_assert(sizeof(ScheduleTag) == 0, "Unknown schedule tag");
}
}
};
} // namespace ck_tile


@@ -0,0 +1,213 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
/// @brief Epilogue Chainer - Modular epilogue processing facilitator
///
/// @par Overview
/// EpilogueChainer provides an interface for processing epilogue operations
/// through schedules. It relies on decomposed epilogue operations, which a
/// Scheduler sequences into operation graphs.
///
/// @tparam Scheduler The schedule provider that defines epilogue operation graphs
template <typename Scheduler>
struct EpilogueChainer
{
using Problem = typename Scheduler::ProblemType;
using BaseOp = typename Scheduler::BaseOp;
using ODataType = typename BaseOp::ODataType;
using DsDataType = typename BaseOp::DsDataType;
using DsLayout = typename BaseOp::DsLayout;
using AccDataType = typename BaseOp::AccDataType;
static constexpr auto MemoryOperation = BaseOp::MemoryOperation;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return BaseOp::GetSmemSize(); }
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
{
return BaseOp::GetVectorSizeC();
}
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> idx)
{
return BaseOp::GetVectorSizeD(idx);
}
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
return BaseOp::MakeLdsDistributionEncode();
}
/// @brief Process epilogue through scheduler-defined operation graph
///
/// @par Flow
/// 1. Create shared context through scheduler
/// 2. Generate operation schedule based on arguments
/// 3. Run scheduled operations in sequence
template <typename OutWindow, typename AccTile, typename AuxWindows, typename... Args>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Args&&... args) const
{
// The context serves as a shared workspace that maintains intermediate results
// and resources across multiple epilogue operations.
auto context = Scheduler::create_context(out_window, acc_tile, aux_windows, p_smem);
auto schedule = Scheduler::make_schedule(std::forward<Args>(args)...);
schedule(out_window, acc_tile, aux_windows, p_smem, context);
}
};
/// @brief Epilogue operation wrapper with arguments
///
/// @par Purpose
/// EpilogueNode wraps individual epilogue operations with their required arguments,
/// allowing them to be composed into operation graphs. Arguments are captured at construction
/// time and automatically forwarded during processing.
///
/// @tparam EpilogueType Epilogue operation (e.g., SliceEpilogue, ScaleEpilogue)
/// @tparam Args Types of arguments required by the epilogue operation
template <typename EpilogueType, typename... Args>
struct EpilogueNode
{
using Epilogue = EpilogueType;
ck_tile::tuple<Args...> args;
constexpr EpilogueNode(Args... a) : args(a...) {}
/// @brief Process epilogue without iteration index
template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Context& context) const
{
ck_tile::apply(
[&](auto&&... epilogue_args) {
EpilogueType{}(out_window,
acc_tile,
aux_windows,
p_smem,
context,
std::forward<decltype(epilogue_args)>(epilogue_args)...);
},
args);
}
/// @brief Process epilogue with iteration index
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename Context,
index_t I>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Context& context,
number<I> iAccess) const
{
ck_tile::apply(
[&](auto&&... epilogue_args) {
EpilogueType{}(out_window,
acc_tile,
aux_windows,
p_smem,
iAccess,
context,
std::forward<decltype(epilogue_args)>(epilogue_args)...);
},
args);
}
};
/// @brief Specialization for epilogue operation wrapper with no arguments
template <typename EpilogueType>
struct EpilogueNode<EpilogueType>
{
using Epilogue = EpilogueType;
ck_tile::tuple<> args;
constexpr EpilogueNode() = default;
template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Context& context) const
{
EpilogueType{}(out_window, acc_tile, aux_windows, p_smem, context);
}
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename Context,
index_t I>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Context& context,
number<I> iAccess) const
{
EpilogueType{}(out_window, acc_tile, aux_windows, p_smem, iAccess, context);
}
};
/// @brief Operation graph that sequentially composes multiple epilogue operations
///
/// @tparam Steps Number of iterations
/// @tparam EpilogueTypes Types of epilogue nodes in the operation graph
template <index_t Steps, typename... EpilogueTypes>
struct EpilogueGraph
{
ck_tile::tuple<EpilogueTypes...> epilogues;
constexpr EpilogueGraph() = default;
constexpr EpilogueGraph(EpilogueTypes... eps) : epilogues(eps...) {}
/// @brief Process all epilogues for each iteration in sequence
template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Context& context) const
{
// For each iteration, process all epilogues in order
static_for<0, Steps, 1>{}([&](auto iAccess) {
static_for<0, sizeof...(EpilogueTypes), 1>{}([&](auto I) {
epilogues.template get<I.value>()(
out_window, acc_tile, aux_windows, p_smem, context, iAccess);
});
});
}
};
/// @brief Helper function for creating epilogue nodes
template <typename EpilogueType, typename... Args>
constexpr auto make_node(Args... args)
{
return EpilogueNode<EpilogueType, Args...>{args...};
}
/// @brief Helper function for creating operation graphs
template <index_t Steps, typename... EpilogueTypes>
constexpr auto make_graph(EpilogueTypes... epilogues)
{
return EpilogueGraph<Steps, EpilogueTypes...>{epilogues...};
}
} // namespace ck_tile