mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
[CK_TILE] Epilogue chaining (Lwpck 3373) (#2773)
* Epilogue chainer * epilogue chainer with context to share state in between epilogues * chain-able epilogues for cshuffle * clang-format * rebase related changes - Added separate chainer test - clang format * comment resolutions * clang-format * Policy based chaining - basic Policy structure to control blanket looping and barrier placement. - to be extended for fine grianed control - to be modified to move possible auto-compute values and SFC access count to policy * Refactoring as per spec - Introduced epilogue schedule, graph - modified chainer to function with graph and schedule * minor_changes - made functions to overload in the epilogue_graph file * clang-format * Documentation and Comments - Added comments to files - Noted changes in changelog - Added README to explain the chainer and current status, exact use steps to be added * Comment resolutions - README modified with the suggested changes - Comment fixed accordingly * major refactoring - modified the chainer files to match the new design - updated comments - updated readme - multi-d example shocases use of the chainer * minor cleanup * tensor and rowcol quant chainer epilogue - added scalarepilogue for tensor quant - added schedule for tensorquant - modified quant example to use chainer and appropriate schedules * Refactor epilogue chainer: generalize ops and standardize context interface Address review comments. Changes: - Rename CastToLdsOp to CastAndStoreToLdsOp for clarity - Standardize context member names (working_tile, out_tile, aux_windows) - Update README documentation with correct operation names - Clean up parameter naming in epilogue_chainer.hpp (OutWindow, AccTile, AuxWindows) - common_epilogue_ops.hpp: General-purpose ops (ScaleScalarOp, CastAndStoreToLdsOp, LoadFromLdsOp, ElementwiseOp, StoreOp, MoveWindowsOp) - cshuffle_epilogue_chainer_ops.hpp: CShuffle-specific context and slice operations - epilogue_chainer.hpp: Cleaned up parameter naming for generality - Removed test files that are no longer needed. These were added for intermediate use * update cshuffle chainer ops file w.r.t cshuffle_epilogue.hpp updates & add chainer to quant gemm example * fix compile errors - CI uses c++17 while the code had c++20 features --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bfac64953f
commit
15e81397a4
@@ -2,6 +2,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp"
|
||||
#include "ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp"
|
||||
#include "ck_tile/ops/epilogue/chainer/cshuffle_epilogue_schedule.hpp"
|
||||
#include "ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
|
||||
61
include/ck_tile/ops/epilogue/chainer/README.md
Normal file
61
include/ck_tile/ops/epilogue/chainer/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# CK Tile Epilogue Chainer
|
||||
|
||||
## Overview
|
||||
|
||||
The Epilogue Chainer provides a modular epilogue processing framework through scheduler-defined operation graphs.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Design Principle
|
||||
The chainer follows a **Scheduler-Graph-Node** architecture with shared context:
|
||||
- **Scheduler**: Defines operation graphs and creates a shared context
|
||||
- **Graph**: Composes multiple operations into sequential processing units
|
||||
- **Node**: Wraps individual epilogue operations with their arguments
|
||||
|
||||
### EpilogueChainer
|
||||
The `EpilogueChainer` struct serves as the modular epilogue processing facilitator. It delegates to schedulers for context creation and schedule generation, then processes the resulting operation graphs.
|
||||
|
||||
### EpilogueNode
|
||||
Individual epilogue operations are wrapped in `EpilogueNode` structures that capture required arguments at construction time and automatically forward them during processing. Supports both parameterized and parameter-free operations.
|
||||
|
||||
### EpilogueGraph
|
||||
The `EpilogueGraph` composes multiple nodes into sequential processing units that iterate over multiple accesses if needed, running all operations in order for each iteration.
|
||||
|
||||
## Files
|
||||
|
||||
### Core Infrastructure
|
||||
- `epilogue_chainer.hpp` - General chainer, node, and graph infrastructure
|
||||
- `common_epilogue_ops.hpp` - Epilogue operations usable with any epilogue type
|
||||
|
||||
### CShuffle Implementation
|
||||
- `cshuffle_epilogue_chainer_ops.hpp` - CShuffle-specific problem, context, and slice operations
|
||||
- `cshuffle_epilogue_schedule.hpp` - CShuffle scheduler with pre-built schedules
|
||||
|
||||
## Usage
|
||||
|
||||
### Common Operations (common_epilogue_ops.hpp)
|
||||
These operations work with any context that provides the standardized interface:
|
||||
- `ScaleScalarOp` - Scale working-tile by scalar values
|
||||
- `CastAndStoreToLdsOp<DstType>` - Cast working-tile and store to LDS
|
||||
- `LoadFromLdsOp<Pattern>` - Load output tile from LDS with sync
|
||||
- `ElementwiseOp<Func, NumAux>` - Apply elementwise operation with auxiliary tensors
|
||||
- `StoreOp<MemOp>` - Store output tile to global memory
|
||||
- `MoveWindowsOp<SFC, NumAux>` - Advance windows to next position
|
||||
|
||||
### CShuffle-Specific Operations (cshuffle_epilogue_chainer_ops.hpp)
|
||||
These operations are specific to CShuffle epilogue:
|
||||
- `CShuffleSliceOp` - Slice accumulator tile based on distribution
|
||||
- `CShuffleScaleWindowOp` - Scale using tensor windows with shuffle distribution
|
||||
|
||||
### Context Interface
|
||||
Operations communicate through a shared context with standardized members:
|
||||
- `working_tile`: Tile for intermediate computations
|
||||
- `out_tile`: Output tile
|
||||
- `aux_windows`: Tuple of auxiliary tensor windows
|
||||
- `lds_write_window`: Window for writing to LDS
|
||||
- `lds_read_window`: Window for reading from LDS
|
||||
|
||||
### Schedule Tags
|
||||
- `DefaultScheduleTag` - Standard: Slice → CastStore → Load → ApplyD → Store → Move
|
||||
- `RowColQuantScheduleTag` - With window scaling
|
||||
- `TensorQuantScheduleTag` - With scalar scaling
|
||||
208
include/ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp
Normal file
208
include/ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp
Normal file
@@ -0,0 +1,208 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
/// @file common_epilogue_ops.hpp
|
||||
/// @brief Reusable simple epilogue operations which might be used to compose more complex one.
|
||||
///
|
||||
///
|
||||
/// @par Context Interface
|
||||
/// Operations expect the context to provide:
|
||||
/// - working_tile: Tile for intermediate computations
|
||||
/// - out_tile: Output tile for final results
|
||||
/// - aux_windows: Tuple of auxiliary tensor windows (e.g., D tensors)
|
||||
/// - lds_write_window: Window for writing to LDS (if using LDS)
|
||||
/// - lds_read_window: Window for reading from LDS (if using LDS)
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Scale working tile by scalar values
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// working_tile: Tile to scale
|
||||
struct ScaleScalarOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context,
|
||||
typename ScaleA,
|
||||
typename ScaleB>
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
[[maybe_unused]] IAccess iAccess,
|
||||
Context& context,
|
||||
const ScaleA& scale_a,
|
||||
const ScaleB& scale_b)
|
||||
{
|
||||
tile_elementwise_inout([&](auto& elem) { elem = elem * scale_a * scale_b; },
|
||||
context.working_tile);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Cast working tile and store to LDS
|
||||
///
|
||||
/// @tparam DataType Target data type for casting
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// working_tile: Tile to cast
|
||||
/// lds_write_window: Window for writing to LDS
|
||||
template <typename DataType>
|
||||
struct CastAndStoreToLdsOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context>
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
[[maybe_unused]] IAccess iAccess,
|
||||
Context& context)
|
||||
{
|
||||
const auto casted_tile = cast_tile<DataType>(context.working_tile);
|
||||
store_tile(context.lds_write_window, casted_tile);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Load output tile from LDS with synchronization
|
||||
///
|
||||
/// @tparam TileEncodingPattern Pattern for tile distribution
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// lds_read_window: Window for reading from LDS
|
||||
/// out_tile: Destination for loaded tile
|
||||
template <typename TileEncodingPattern>
|
||||
struct LoadFromLdsOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context>
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
[[maybe_unused]] IAccess iAccess,
|
||||
Context& context)
|
||||
{
|
||||
constexpr auto tile_distribution = TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
block_sync_lds();
|
||||
context.out_tile = load_tile(make_tile_window(context.lds_read_window, tile_distribution));
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Apply elementwise operation with auxiliary tensors
|
||||
///
|
||||
/// @tparam Elementwise Elementwise functor type
|
||||
/// @tparam NumAux Number of auxiliary tensors to load and apply
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// out_tile: In/out tile for elementwise operation
|
||||
/// aux_windows: Tuple of auxiliary tensor windows
|
||||
template <typename Elementwise, index_t NumAux>
|
||||
struct ElementwiseOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context>
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
[[maybe_unused]] IAccess iAccess,
|
||||
Context& context)
|
||||
{
|
||||
const auto aux_tiles = generate_tuple(
|
||||
[&](auto idx) { return load_tile(context.aux_windows[idx]); }, number<NumAux>{});
|
||||
|
||||
const auto tiles = concat_tuple_of_reference(
|
||||
tie(context.out_tile, context.out_tile),
|
||||
generate_tie([&](auto idx) -> const auto& { return aux_tiles[idx]; },
|
||||
number<NumAux>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(Elementwise{}, tiles);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Store output tile to global memory
|
||||
///
|
||||
/// @tparam MemOp Memory operation type (set or atomic_add)
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// out_tile: Tile to store
|
||||
template <memory_operation_enum MemOp>
|
||||
struct StoreOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context>
|
||||
CK_TILE_DEVICE void operator()(OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
[[maybe_unused]] IAccess iAccess,
|
||||
Context& context)
|
||||
{
|
||||
if constexpr(MemOp == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_window, context.out_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_window, context.out_tile);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Move output and auxiliary windows by step from space-filling curve
|
||||
///
|
||||
/// @tparam SFC Space filling curve type providing step computation
|
||||
/// @tparam NumAux Number of auxiliary windows to move
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// aux_windows: Tuple of windows to move
|
||||
template <typename SFC, index_t NumAux>
|
||||
struct MoveWindowsOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context>
|
||||
CK_TILE_DEVICE void operator()(OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
IAccess iAccess,
|
||||
Context& context)
|
||||
{
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(out_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumAux, 1>{}([&](auto idx) {
|
||||
move_tile_window(context.aux_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,512 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// CShuffle-specific epilogue operations
|
||||
// These operations are specific to CShuffle epilogue due to its unique.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
/// @brief Slice accumulator tile for CShuffle epilogue
|
||||
///
|
||||
/// @par Purpose
|
||||
/// Extracts a portion of the accumulator tile into the working tile
|
||||
/// based on the current iteration index. This is CShuffle-specific.
|
||||
///
|
||||
/// @tparam SFC Space filling curve type
|
||||
/// @tparam CWarpDstr Warp distribution type
|
||||
/// @tparam NumMXdlPerWavePerShuffle XDL tiles in M per wave per shuffle
|
||||
/// @tparam NumNXdlPerWavePerShuffle XDL tiles in N per wave per shuffle
|
||||
/// @tparam MPerIterShuffle M elements per shuffle iteration
|
||||
/// @tparam NPerIterShuffle N elements per shuffle iteration
|
||||
template <typename SFC,
|
||||
typename CWarpDstr,
|
||||
index_t NumMXdlPerWavePerShuffle,
|
||||
index_t NumNXdlPerWavePerShuffle,
|
||||
index_t MPerIterShuffle,
|
||||
index_t NPerIterShuffle>
|
||||
struct CShuffleSliceOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context>
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
|
||||
const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
IAccess iAccess,
|
||||
Context& context)
|
||||
{
|
||||
constexpr auto idx_start = SFC::get_index(iAccess);
|
||||
constexpr auto m_iter = number<idx_start.at(number<0>{}) / MPerIterShuffle>{};
|
||||
constexpr auto n_iter = number<idx_start.at(number<1>{}) / NPerIterShuffle>{};
|
||||
|
||||
constexpr auto warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
context.working_tile.get_thread_buffer() = acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<m_iter * NumMXdlPerWavePerShuffle, n_iter * NumNXdlPerWavePerShuffle>{},
|
||||
warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
warp_y_lengths));
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Scale working tile using tensor windows (CShuffle-specific)
|
||||
///
|
||||
/// @par Purpose
|
||||
/// Scales the working tile using row and column scale tensors.
|
||||
/// CShuffle-specific because it creates scale windows from the
|
||||
/// working tile's distribution and handles window movement.
|
||||
///
|
||||
/// @tparam SFC Space filling curve type
|
||||
template <typename SFC>
|
||||
struct CShuffleScaleWindowOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context,
|
||||
typename ScaleRowTensor,
|
||||
typename ScaleColTensor>
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
IAccess iAccess,
|
||||
Context& context,
|
||||
const ScaleRowTensor& scale_row_tensor,
|
||||
const ScaleColTensor& scale_col_tensor)
|
||||
{
|
||||
auto scale_row_window =
|
||||
make_tile_window(scale_row_tensor, context.working_tile.get_tile_distribution());
|
||||
auto scale_col_window =
|
||||
make_tile_window(scale_col_tensor, context.working_tile.get_tile_distribution());
|
||||
|
||||
const auto scale_row_tile = load_tile(scale_row_window);
|
||||
const auto scale_col_tile = load_tile(scale_col_window);
|
||||
|
||||
tile_elementwise_inout(element_wise::MultiDMultiply{},
|
||||
context.working_tile,
|
||||
context.working_tile,
|
||||
scale_row_tile,
|
||||
scale_col_tile);
|
||||
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(number<iAccess>{});
|
||||
move_tile_window(scale_row_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
move_tile_window(scale_col_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// CShuffle problem and base operation definitions
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
/// @brief Problem configuration for CShuffle epilogue chainer operations
|
||||
/// @note Mirrors CShuffleEpilogueProblem but uses AsDataType/BsDataType for tuple support
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename CDElementwise_,
|
||||
index_t kM_,
|
||||
index_t kN_,
|
||||
index_t MWave_,
|
||||
index_t NWave_,
|
||||
index_t MPerXdl_,
|
||||
index_t NPerXdl_,
|
||||
index_t KPerXdl_,
|
||||
bool isCTransposed_,
|
||||
memory_operation_enum MemoryOperation_,
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1,
|
||||
bool TiledMMAPermuteN_ = false,
|
||||
index_t BlockedXDLN_PerWarp_ = 1>
|
||||
struct CShuffleEpilogueChainProblem
|
||||
{
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using ELayout = remove_cvref_t<ELayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t MWave = MWave_;
|
||||
static constexpr index_t NWave = NWave_;
|
||||
static constexpr index_t MPerXdl = MPerXdl_;
|
||||
static constexpr index_t NPerXdl = NPerXdl_;
|
||||
static constexpr index_t KPerXdl = KPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct CShuffleEpilogueChainBaseOp
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using ELayout = remove_cvref_t<typename Problem::ELayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr index_t MWave = Problem::MWave;
|
||||
static constexpr index_t NWave = Problem::NWave;
|
||||
static constexpr index_t MPerXdl = Problem::MPerXdl;
|
||||
static constexpr index_t NPerXdl = Problem::NPerXdl;
|
||||
static constexpr index_t KPerXdl = Problem::KPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t NumDTensor = Problem::NumDTensor;
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
|
||||
/**
|
||||
* @brief Get the vector store size for C tensor.
|
||||
*
|
||||
* @note The vector store size for output C tensor would depend on multiple factors
|
||||
* like its data layout and warp gemm C transposition. In general it would
|
||||
* be the number of consecutive elements in contiguous C dimension hold by
|
||||
* single thread.
|
||||
*
|
||||
* @return The vector store size for C tensor.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
|
||||
{
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeC;
|
||||
}
|
||||
constexpr index_t max_vector_size = 16;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(ODataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported ELayout!");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the vector store size for Di tensor.
|
||||
*
|
||||
* @return The vector store size for Di tensor.
|
||||
*/
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
|
||||
{
|
||||
constexpr index_t max_vector_size = 16;
|
||||
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported DLayout!");
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Shuffle tile configuration parameters
|
||||
*
|
||||
* @details These parameters control the number of XDL tiles processed per wave in each shuffle
|
||||
* iteration:
|
||||
* - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
|
||||
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
|
||||
*/
|
||||
static constexpr auto shuffle_tile_tuple = [] {
|
||||
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
|
||||
if constexpr(elem_per_thread >= GetVectorSizeC())
|
||||
{
|
||||
return std::make_tuple(1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
|
||||
(kMPerBlock % num_xdl_shuffles == 0),
|
||||
"kMPerBlock must be divisible by MPerXdl*MWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogueStageBase");
|
||||
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
|
||||
(kNPerBlock % num_xdl_shuffles == 0),
|
||||
"kNPerBlock must be divisible by NPerXdl*NWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogueStageBase");
|
||||
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
|
||||
}
|
||||
}
|
||||
}();
|
||||
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
|
||||
static constexpr index_t NumNXdlPerWavePerShuffle =
|
||||
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
|
||||
|
||||
static constexpr auto MNPerIterationShuffle = [] {
|
||||
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
|
||||
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
|
||||
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
|
||||
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
|
||||
else
|
||||
return std::make_tuple(m_val, n_val);
|
||||
}();
|
||||
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
|
||||
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
|
||||
|
||||
using WG = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
AccDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
|
||||
{
|
||||
// N is contiguous dimension
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
|
||||
}
|
||||
// M is contiguous dimension
|
||||
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported ELayout!");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
|
||||
{
|
||||
constexpr auto block_outer_dstr_encoding = [] {
|
||||
if constexpr(BlockedXDLN_PerWarp == 1)
|
||||
{
|
||||
return tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<NumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
|
||||
// BlockedLayout
|
||||
return tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{};
|
||||
}
|
||||
}();
|
||||
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
|
||||
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
|
||||
|
||||
return block_dstr_encoding;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
|
||||
}
|
||||
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
Problem::kNumWaveGroups>;
|
||||
|
||||
/// @brief Context structure for CShuffle epilogue operations
|
||||
///
|
||||
/// @par Purpose
|
||||
/// The context serves as a shared workspace that maintains intermediate results
|
||||
/// and resources across multiple epilogue operations. It eliminates the need for
|
||||
/// operations to recreate shared data structures and enables data flow
|
||||
/// through the operation graph.
|
||||
///
|
||||
/// @par Standardized Interface
|
||||
/// Uses standardized member names so common operations can work with this context:
|
||||
/// - working_tile: Intermediate tile for shuffle operations
|
||||
/// - out_tile: Output tile for final results
|
||||
/// - aux_windows: Auxiliary tensor windows (D tensors)
|
||||
/// - lds_write_window: Window for writing to LDS
|
||||
/// - lds_read_window: Window for reading from LDS
|
||||
template <typename WorkingTileType,
|
||||
typename LdsBlockType,
|
||||
typename LdsWriteWindowType,
|
||||
typename LdsReadWindowType,
|
||||
typename AuxWindowsType,
|
||||
typename OutTileType>
|
||||
struct CShuffleContext
|
||||
{
|
||||
WorkingTileType working_tile; // Working tile for shuffle operations
|
||||
LdsBlockType lds_block; // LDS block view
|
||||
LdsWriteWindowType lds_write_window; // Window for writing to LDS
|
||||
LdsReadWindowType lds_read_window; // Window for reading from LDS
|
||||
AuxWindowsType aux_windows; // Auxiliary tensor windows (D tensors)
|
||||
OutTileType out_tile; // Output tile
|
||||
};
|
||||
|
||||
template <typename OutDramWindow, typename AccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()([[maybe_unused]] OutDramWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
const DsDramWindows& ds_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
"Currently, the CShuffleEpilogueStageBase only supports the Row Major Output layout");
|
||||
|
||||
constexpr auto working_tile_distr =
|
||||
make_static_tile_distribution(MakeLdsDistributionEncode());
|
||||
auto working_tile = make_static_distributed_tensor<AccDataType>(working_tile_distr);
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto lds_block = make_tensor_view<address_space_enum::lds>(static_cast<ODataType*>(p_smem),
|
||||
lds_block_desc);
|
||||
|
||||
auto lds_write_window = make_tile_window(
|
||||
lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0},
|
||||
working_tile_distr);
|
||||
|
||||
auto lds_read_window = make_tile_window(
|
||||
lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
constexpr auto dram_tile_distribution =
|
||||
TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
auto aux_windows = generate_tuple(
|
||||
[&](auto idx) { return make_tile_window(ds_windows[idx], dram_tile_distribution); },
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto out_tile = load_tile(make_tile_window(lds_read_window, dram_tile_distribution));
|
||||
|
||||
using ContextType = CShuffleContext<decltype(working_tile),
|
||||
decltype(lds_block),
|
||||
decltype(lds_write_window),
|
||||
decltype(lds_read_window),
|
||||
decltype(aux_windows),
|
||||
decltype(out_tile)>;
|
||||
|
||||
ContextType context;
|
||||
context.working_tile = working_tile;
|
||||
context.lds_block = lds_block;
|
||||
context.lds_write_window = lds_write_window;
|
||||
context.lds_read_window = lds_read_window;
|
||||
context.aux_windows = aux_windows;
|
||||
context.out_tile = out_tile;
|
||||
|
||||
return context;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,129 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp"
|
||||
#include "ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp"
|
||||
#include "ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Schedule type tags for epilogue selection
|
||||
/// @par Purpose
|
||||
/// Each tag corresponds to a pre-built schedule, these are used to select a schedule
|
||||
|
||||
/// Standard epilogue schedule: Slice → CastStore → Load → ApplyD → Store → Move
|
||||
struct DefaultScheduleTag
|
||||
{
|
||||
};
|
||||
|
||||
/// RowCol quantization schedule: Slice → ScaleWindow → CastStore → Load → ApplyD → Store → Move
|
||||
struct RowColQuantScheduleTag
|
||||
{
|
||||
};
|
||||
|
||||
/// Tensor quantization schedule: Slice → ScaleScalar → CastStore → Load → ApplyD → Store → Move
|
||||
struct TensorQuantScheduleTag
|
||||
{
|
||||
};
|
||||
|
||||
/// @brief CShuffle epilogue scheduler providing pre-built schedules
|
||||
///
|
||||
/// @par Overview
|
||||
/// CshuffleEpilogueSchedule acts as the scheduler component for EpilogueChainer.
|
||||
/// It provides context creation and pre-built schedules. The scheduler
|
||||
/// uses tags to select/create appropriate epilogue schedule.
|
||||
///
|
||||
/// @tparam Problem The epilogue problem configuration
|
||||
/// @tparam ScheduleTag Tag selecting the epilogue schedule
|
||||
template <typename Problem, typename ScheduleTag = DefaultScheduleTag>
|
||||
struct CshuffleEpilogueSchedule
|
||||
{
|
||||
using ProblemType = Problem;
|
||||
using BaseOp = CShuffleEpilogueChainBaseOp<Problem>;
|
||||
|
||||
static constexpr index_t NumAccess = BaseOp::SFC::get_num_of_access();
|
||||
|
||||
/// @brief Create context for epilogue operations
|
||||
template <typename OutWindow, typename AccTile, typename AuxWindows>
|
||||
CK_TILE_DEVICE static auto create_context(OutWindow& out_window,
|
||||
const AccTile& acc_tile,
|
||||
const AuxWindows& aux_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
return BaseOp{}(out_window, acc_tile, aux_windows, p_smem);
|
||||
}
|
||||
|
||||
/// @brief Make schedule based on compile-time tag selection
|
||||
template <typename... Args>
|
||||
CK_TILE_DEVICE static auto make_schedule(Args&&... args)
|
||||
{
|
||||
if constexpr(std::is_same_v<ScheduleTag, DefaultScheduleTag>)
|
||||
{
|
||||
// Standard epilogue
|
||||
// Schedule: Slice -> CastAndStoreLds -> Load -> ApplyD -> Store -> MoveWindows
|
||||
static_assert(sizeof...(args) == 0, "DefaultSchedule expects no arguments");
|
||||
return make_graph<NumAccess>(
|
||||
make_node<CShuffleSliceOp<typename BaseOp::SFC,
|
||||
typename BaseOp::CWarpDstr,
|
||||
BaseOp::NumMXdlPerWavePerShuffle,
|
||||
BaseOp::NumNXdlPerWavePerShuffle,
|
||||
BaseOp::MPerIterationShuffle,
|
||||
BaseOp::NPerIterationShuffle>>(),
|
||||
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
|
||||
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
|
||||
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
|
||||
make_node<StoreOp<Problem::MemoryOperation>>(),
|
||||
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
|
||||
}
|
||||
else if constexpr(std::is_same_v<ScheduleTag, RowColQuantScheduleTag>)
|
||||
{
|
||||
// RowCol quantization schedule with tensor windows
|
||||
// Schedule: Slice -> ScaleWindow -> CastAndStoreLds -> Load -> ApplyD -> Store ->
|
||||
// MoveWindows
|
||||
static_assert(sizeof...(args) == 2,
|
||||
"RowColQuantSchedule requires exactly 2 scale tensor arguments");
|
||||
return make_graph<NumAccess>(
|
||||
make_node<CShuffleSliceOp<typename BaseOp::SFC,
|
||||
typename BaseOp::CWarpDstr,
|
||||
BaseOp::NumMXdlPerWavePerShuffle,
|
||||
BaseOp::NumNXdlPerWavePerShuffle,
|
||||
BaseOp::MPerIterationShuffle,
|
||||
BaseOp::NPerIterationShuffle>>(),
|
||||
make_node<CShuffleScaleWindowOp<typename BaseOp::SFC>>(std::forward<Args>(args)...),
|
||||
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
|
||||
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
|
||||
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
|
||||
make_node<StoreOp<Problem::MemoryOperation>>(),
|
||||
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
|
||||
}
|
||||
else if constexpr(std::is_same_v<ScheduleTag, TensorQuantScheduleTag>)
|
||||
{
|
||||
// Tensor quantization schedule with scalar values
|
||||
// Schedule: Slice -> ScaleScalar -> CastAndStoreLds -> Load -> ApplyD -> Store ->
|
||||
// MoveWindows
|
||||
static_assert(sizeof...(args) == 2,
|
||||
"TensorQuantSchedule requires exactly 2 scalar arguments");
|
||||
return make_graph<NumAccess>(
|
||||
make_node<CShuffleSliceOp<typename BaseOp::SFC,
|
||||
typename BaseOp::CWarpDstr,
|
||||
BaseOp::NumMXdlPerWavePerShuffle,
|
||||
BaseOp::NumNXdlPerWavePerShuffle,
|
||||
BaseOp::MPerIterationShuffle,
|
||||
BaseOp::NPerIterationShuffle>>(),
|
||||
make_node<ScaleScalarOp>(std::forward<Args>(args)...),
|
||||
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
|
||||
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
|
||||
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
|
||||
make_node<StoreOp<Problem::MemoryOperation>>(),
|
||||
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown schedule tag");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
213
include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp
Normal file
213
include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp
Normal file
@@ -0,0 +1,213 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Epilogue Chainer - Modular epilogue processing facilitator
|
||||
///
|
||||
/// @par Overview
|
||||
/// EpilogueChainer provides an interface for processing epilogue operations
|
||||
/// through schedules. The chainer uses decomposed epilogue operations, these are
|
||||
/// scheduled/sequenced by a Scheduler to form operation graphs.
|
||||
///
|
||||
/// @tparam Scheduler The schedule provider that defines epilogue operation graphs
|
||||
template <typename Scheduler>
|
||||
struct EpilogueChainer
|
||||
{
|
||||
using Problem = typename Scheduler::ProblemType;
|
||||
using BaseOp = typename Scheduler::BaseOp;
|
||||
|
||||
using ODataType = typename BaseOp::ODataType;
|
||||
using DsDataType = typename BaseOp::DsDataType;
|
||||
using DsLayout = typename BaseOp::DsLayout;
|
||||
using AccDataType = typename BaseOp::AccDataType;
|
||||
static constexpr auto MemoryOperation = BaseOp::MemoryOperation;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return BaseOp::GetSmemSize(); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
|
||||
{
|
||||
return BaseOp::GetVectorSizeC();
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> idx)
|
||||
{
|
||||
return BaseOp::GetVectorSizeD(idx);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
|
||||
{
|
||||
return BaseOp::MakeLdsDistributionEncode();
|
||||
}
|
||||
|
||||
/// @brief Process epilogue through scheduler-defined operation graph
|
||||
///
|
||||
/// @par Flow
|
||||
/// 1. Create shared context through scheduler
|
||||
/// 2. Generate operation schedule based on arguments
|
||||
/// 3. Run scheduled operations in sequence
|
||||
template <typename OutWindow, typename AccTile, typename AuxWindows, typename... Args>
|
||||
CK_TILE_DEVICE void operator()(OutWindow& out_window,
|
||||
const AccTile& acc_tile,
|
||||
const AuxWindows& aux_windows,
|
||||
void* p_smem,
|
||||
Args&&... args) const
|
||||
{
|
||||
// The context serves as a shared workspace that maintains intermediate results
|
||||
// and resources across multiple epilogue operations.
|
||||
auto context = Scheduler::create_context(out_window, acc_tile, aux_windows, p_smem);
|
||||
auto schedule = Scheduler::make_schedule(std::forward<Args>(args)...);
|
||||
schedule(out_window, acc_tile, aux_windows, p_smem, context);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Epilogue operation wrapper with arguments
|
||||
///
|
||||
/// @par Purpose
|
||||
/// EpilogueNode wraps individual epilogue operations with their required arguments,
|
||||
/// allowing them to be composed into operation graphs. Arguments are captured at construction
|
||||
/// time and automatically forwarded during processing.
|
||||
///
|
||||
/// @tparam EpilogueType Epilogue operation (e.g., SliceEpilogue, ScaleEpilogue)
|
||||
/// @tparam Args Types of arguments required by the epilogue operation
|
||||
template <typename EpilogueType, typename... Args>
|
||||
struct EpilogueNode
|
||||
{
|
||||
using Epilogue = EpilogueType;
|
||||
ck_tile::tuple<Args...> args;
|
||||
|
||||
constexpr EpilogueNode(Args... a) : args(a...) {}
|
||||
|
||||
/// @brief Process epilogue without iteration index
|
||||
template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
|
||||
CK_TILE_DEVICE void operator()(OutWindow& out_window,
|
||||
const AccTile& acc_tile,
|
||||
const AuxWindows& aux_windows,
|
||||
void* p_smem,
|
||||
Context& context) const
|
||||
{
|
||||
ck_tile::apply(
|
||||
[&](auto&&... epilogue_args) {
|
||||
EpilogueType{}(out_window,
|
||||
acc_tile,
|
||||
aux_windows,
|
||||
p_smem,
|
||||
context,
|
||||
std::forward<decltype(epilogue_args)>(epilogue_args)...);
|
||||
},
|
||||
args);
|
||||
}
|
||||
|
||||
/// @brief Process epilogue with iteration index
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename Context,
|
||||
index_t I>
|
||||
CK_TILE_DEVICE void operator()(OutWindow& out_window,
|
||||
const AccTile& acc_tile,
|
||||
const AuxWindows& aux_windows,
|
||||
void* p_smem,
|
||||
Context& context,
|
||||
number<I> iAccess) const
|
||||
{
|
||||
ck_tile::apply(
|
||||
[&](auto&&... epilogue_args) {
|
||||
EpilogueType{}(out_window,
|
||||
acc_tile,
|
||||
aux_windows,
|
||||
p_smem,
|
||||
iAccess,
|
||||
context,
|
||||
std::forward<decltype(epilogue_args)>(epilogue_args)...);
|
||||
},
|
||||
args);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Specialization for epilogue operation wrapper with no arguments
|
||||
template <typename EpilogueType>
|
||||
struct EpilogueNode<EpilogueType>
|
||||
{
|
||||
using Epilogue = EpilogueType;
|
||||
ck_tile::tuple<> args;
|
||||
|
||||
constexpr EpilogueNode() = default;
|
||||
|
||||
template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
|
||||
CK_TILE_DEVICE void operator()(OutWindow& out_window,
|
||||
const AccTile& acc_tile,
|
||||
const AuxWindows& aux_windows,
|
||||
void* p_smem,
|
||||
Context& context) const
|
||||
{
|
||||
EpilogueType{}(out_window, acc_tile, aux_windows, p_smem, context);
|
||||
}
|
||||
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename Context,
|
||||
index_t I>
|
||||
CK_TILE_DEVICE void operator()(OutWindow& out_window,
|
||||
const AccTile& acc_tile,
|
||||
const AuxWindows& aux_windows,
|
||||
void* p_smem,
|
||||
Context& context,
|
||||
number<I> iAccess) const
|
||||
{
|
||||
EpilogueType{}(out_window, acc_tile, aux_windows, p_smem, iAccess, context);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Operation graph that sequentially composes multiple epilogue operations
|
||||
///
|
||||
/// @tparam Steps Number of iterations
|
||||
/// @tparam EpilogueTypes Types of epilogue nodes in the operation graph
|
||||
template <index_t Steps, typename... EpilogueTypes>
|
||||
struct EpilogueGraph
|
||||
{
|
||||
ck_tile::tuple<EpilogueTypes...> epilogues;
|
||||
|
||||
constexpr EpilogueGraph() = default;
|
||||
constexpr EpilogueGraph(EpilogueTypes... eps) : epilogues(eps...) {}
|
||||
|
||||
/// @brief Process all epilogues for each iteration in sequence
|
||||
template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
|
||||
CK_TILE_DEVICE void operator()(OutWindow& out_window,
|
||||
const AccTile& acc_tile,
|
||||
const AuxWindows& aux_windows,
|
||||
void* p_smem,
|
||||
Context& context) const
|
||||
{
|
||||
// For each iteration, process all epilogues in order
|
||||
static_for<0, Steps, 1>{}([&](auto iAccess) {
|
||||
static_for<0, sizeof...(EpilogueTypes), 1>{}([&](auto I) {
|
||||
epilogues.template get<I.value>()(
|
||||
out_window, acc_tile, aux_windows, p_smem, context, iAccess);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Helper function for creating epilogue nodes
|
||||
template <typename EpilogueType, typename... Args>
|
||||
constexpr auto make_node(Args... args)
|
||||
{
|
||||
return EpilogueNode<EpilogueType, Args...>{args...};
|
||||
}
|
||||
|
||||
/// @brief Helper function for creating operation graphs
|
||||
template <index_t Steps, typename... EpilogueTypes>
|
||||
constexpr auto make_graph(EpilogueTypes... epilogues)
|
||||
{
|
||||
return EpilogueGraph<Steps, EpilogueTypes...>{epilogues...};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user