mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
This commit is contained in:
61
include/ck_tile/ops/epilogue/chainer/README.md
Normal file
61
include/ck_tile/ops/epilogue/chainer/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# CK Tile Epilogue Chainer
|
||||
|
||||
## Overview
|
||||
|
||||
The Epilogue Chainer provides a modular epilogue processing framework through scheduler-defined operation graphs.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Design Principle
|
||||
The chainer follows a **Scheduler-Graph-Node** architecture with shared context:
|
||||
- **Scheduler**: Defines operation graphs and creates a shared context
|
||||
- **Graph**: Composes multiple operations into sequential processing units
|
||||
- **Node**: Wraps individual epilogue operations with their arguments
|
||||
|
||||
### EpilogueChainer
|
||||
The `EpilogueChainer` struct serves as the modular epilogue processing facilitator. It delegates to schedulers for context creation and schedule generation, then processes the resulting operation graphs.
|
||||
|
||||
### EpilogueNode
|
||||
Individual epilogue operations are wrapped in `EpilogueNode` structures that capture required arguments at construction time and automatically forward them during processing. Supports both parameterized and parameter-free operations.
|
||||
|
||||
### EpilogueGraph
|
||||
The `EpilogueGraph` composes multiple nodes into sequential processing units that iterate over multiple accesses if needed, running all operations in order for each iteration.
|
||||
|
||||
## Files
|
||||
|
||||
### Core Infrastructure
|
||||
- `epilogue_chainer.hpp` - General chainer, node, and graph infrastructure
|
||||
- `common_epilogue_ops.hpp` - Epilogue operations usable with any epilogue type
|
||||
|
||||
### CShuffle Implementation
|
||||
- `cshuffle_epilogue_chainer_ops.hpp` - CShuffle-specific problem, context, and slice operations
|
||||
- `cshuffle_epilogue_schedule.hpp` - CShuffle scheduler with pre-built schedules
|
||||
|
||||
## Usage
|
||||
|
||||
### Common Operations (common_epilogue_ops.hpp)
|
||||
These operations work with any context that provides the standardized interface:
|
||||
- `ScaleScalarOp` - Scale working-tile by scalar values
|
||||
- `CastAndStoreToLdsOp<DstType>` - Cast working-tile and store to LDS
|
||||
- `LoadFromLdsOp<Pattern>` - Load output tile from LDS with sync
|
||||
- `ElementwiseOp<Func, NumAux>` - Apply elementwise operation with auxiliary tensors
|
||||
- `StoreOp<MemOp>` - Store output tile to global memory
|
||||
- `MoveWindowsOp<SFC, NumAux>` - Advance windows to next position
|
||||
|
||||
### CShuffle-Specific Operations (cshuffle_epilogue_chainer_ops.hpp)
|
||||
These operations are specific to CShuffle epilogue:
|
||||
- `CShuffleSliceOp` - Slice accumulator tile based on distribution
|
||||
- `CShuffleScaleWindowOp` - Scale using tensor windows with shuffle distribution
|
||||
|
||||
### Context Interface
|
||||
Operations communicate through a shared context with standardized members:
|
||||
- `working_tile`: Tile for intermediate computations
|
||||
- `out_tile`: Output tile
|
||||
- `aux_windows`: Tuple of auxiliary tensor windows
|
||||
- `lds_write_window`: Window for writing to LDS
|
||||
- `lds_read_window`: Window for reading from LDS
|
||||
|
||||
### Schedule Tags
|
||||
- `DefaultScheduleTag` - Standard: Slice → CastStore → Load → ApplyD → Store → Move
|
||||
- `RowColQuantScheduleTag` - With window scaling
|
||||
- `TensorQuantScheduleTag` - With scalar scaling
|
||||
208
include/ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp
Normal file
208
include/ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp
Normal file
@@ -0,0 +1,208 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
/// @file common_epilogue_ops.hpp
|
||||
/// @brief Reusable simple epilogue operations which might be used to compose more complex ones.
|
||||
///
|
||||
///
|
||||
/// @par Context Interface
|
||||
/// Operations expect the context to provide:
|
||||
/// - working_tile: Tile for intermediate computations
|
||||
/// - out_tile: Output tile for final results
|
||||
/// - aux_windows: Tuple of auxiliary tensor windows (e.g., D tensors)
|
||||
/// - lds_write_window: Window for writing to LDS (if using LDS)
|
||||
/// - lds_read_window: Window for reading from LDS (if using LDS)
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Scale working tile by scalar values
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// working_tile: Tile to scale
|
||||
struct ScaleScalarOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context,
|
||||
typename ScaleA,
|
||||
typename ScaleB>
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
[[maybe_unused]] IAccess iAccess,
|
||||
Context& context,
|
||||
const ScaleA& scale_a,
|
||||
const ScaleB& scale_b)
|
||||
{
|
||||
tile_elementwise_inout([&](auto& elem) { elem = elem * scale_a * scale_b; },
|
||||
context.working_tile);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Cast working tile and store to LDS
|
||||
///
|
||||
/// @tparam DataType Target data type for casting
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// working_tile: Tile to cast
|
||||
/// lds_write_window: Window for writing to LDS
|
||||
template <typename DataType>
|
||||
struct CastAndStoreToLdsOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context>
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
[[maybe_unused]] IAccess iAccess,
|
||||
Context& context)
|
||||
{
|
||||
const auto casted_tile = cast_tile<DataType>(context.working_tile);
|
||||
store_tile(context.lds_write_window, casted_tile);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Load output tile from LDS with synchronization
|
||||
///
|
||||
/// @tparam TileEncodingPattern Pattern for tile distribution
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// lds_read_window: Window for reading from LDS
|
||||
/// out_tile: Destination for loaded tile
|
||||
template <typename TileEncodingPattern>
|
||||
struct LoadFromLdsOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context>
|
||||
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
[[maybe_unused]] IAccess iAccess,
|
||||
Context& context)
|
||||
{
|
||||
constexpr auto tile_distribution = TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
block_sync_lds();
|
||||
context.out_tile = load_tile(make_tile_window(context.lds_read_window, tile_distribution));
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Apply elementwise operation with auxiliary tensors
|
||||
///
|
||||
/// @tparam Elementwise Elementwise functor type
|
||||
/// @tparam NumAux Number of auxiliary tensors to load and apply
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// out_tile: In/out tile for elementwise operation
|
||||
/// aux_windows: Tuple of auxiliary tensor windows
|
||||
template <typename Elementwise, index_t NumAux>
struct ElementwiseOp
{
    // In-place application of the elementwise functor:
    //   out_tile = Elementwise(out_tile, aux_0, ..., aux_{NumAux-1})
    // Note: the auxiliary tiles are loaded from context.aux_windows, not from the
    // aux_windows parameter (which is unused here).
    template <typename OutWindow,
              typename AccTile,
              typename AuxWindows,
              typename IAccess,
              typename Context>
    CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
                                   [[maybe_unused]] const AccTile& acc_tile,
                                   [[maybe_unused]] const AuxWindows& aux_windows,
                                   [[maybe_unused]] void* p_smem,
                                   [[maybe_unused]] IAccess iAccess,
                                   Context& context)
    {
        // Load each auxiliary tensor tile from the context's windows.
        const auto aux_tiles = generate_tuple(
            [&](auto idx) { return load_tile(context.aux_windows[idx]); }, number<NumAux>{});

        // out_tile appears twice on purpose: once as the destination and once as
        // the first source operand of the functor; the aux tiles follow as the
        // remaining sources.
        const auto tiles = concat_tuple_of_reference(
            tie(context.out_tile, context.out_tile),
            generate_tie([&](auto idx) -> const auto& { return aux_tiles[idx]; },
                         number<NumAux>{}));

        // Unpack the tuple of references and apply the functor elementwise.
        tile_elementwise_inout_unpack(Elementwise{}, tiles);
    }
};
|
||||
|
||||
/// @brief Store output tile to global memory
|
||||
///
|
||||
/// @tparam MemOp Memory operation type (set or atomic_add)
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// out_tile: Tile to store
|
||||
template <memory_operation_enum MemOp>
|
||||
struct StoreOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context>
|
||||
CK_TILE_DEVICE void operator()(OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
[[maybe_unused]] IAccess iAccess,
|
||||
Context& context)
|
||||
{
|
||||
if constexpr(MemOp == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_window, context.out_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_window, context.out_tile);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief Move output and auxiliary windows by step from space-filling curve
|
||||
///
|
||||
/// @tparam SFC Space filling curve type providing step computation
|
||||
/// @tparam NumAux Number of auxiliary windows to move
|
||||
///
|
||||
/// @par Context Requirements
|
||||
/// aux_windows: Tuple of windows to move
|
||||
template <typename SFC, index_t NumAux>
|
||||
struct MoveWindowsOp
|
||||
{
|
||||
template <typename OutWindow,
|
||||
typename AccTile,
|
||||
typename AuxWindows,
|
||||
typename IAccess,
|
||||
typename Context>
|
||||
CK_TILE_DEVICE void operator()(OutWindow& out_window,
|
||||
[[maybe_unused]] const AccTile& acc_tile,
|
||||
[[maybe_unused]] const AuxWindows& aux_windows,
|
||||
[[maybe_unused]] void* p_smem,
|
||||
IAccess iAccess,
|
||||
Context& context)
|
||||
{
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
|
||||
move_tile_window(out_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
static_for<0, NumAux, 1>{}([&](auto idx) {
|
||||
move_tile_window(context.aux_windows[idx],
|
||||
{step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,512 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#include <optional>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// CShuffle-specific epilogue operations
|
||||
// These operations are specific to the CShuffle epilogue due to its unique
// accumulator slicing and tile-distribution requirements.
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
/// @brief Slice accumulator tile for CShuffle epilogue
|
||||
///
|
||||
/// @par Purpose
|
||||
/// Extracts a portion of the accumulator tile into the working tile
|
||||
/// based on the current iteration index. This is CShuffle-specific.
|
||||
///
|
||||
/// @tparam SFC Space filling curve type
|
||||
/// @tparam CWarpDstr Warp distribution type
|
||||
/// @tparam NumMXdlPerWavePerShuffle XDL tiles in M per wave per shuffle
|
||||
/// @tparam NumNXdlPerWavePerShuffle XDL tiles in N per wave per shuffle
|
||||
/// @tparam MPerIterShuffle M elements per shuffle iteration
|
||||
/// @tparam NPerIterShuffle N elements per shuffle iteration
|
||||
template <typename SFC,
          typename CWarpDstr,
          index_t NumMXdlPerWavePerShuffle,
          index_t NumNXdlPerWavePerShuffle,
          index_t MPerIterShuffle,
          index_t NPerIterShuffle>
struct CShuffleSliceOp
{
    // Copies the slice of acc_tile that belongs to shuffle iteration iAccess
    // into context.working_tile (thread-local data only).
    template <typename OutWindow,
              typename AccTile,
              typename AuxWindows,
              typename IAccess,
              typename Context>
    CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
                                   const AccTile& acc_tile,
                                   [[maybe_unused]] const AuxWindows& aux_windows,
                                   [[maybe_unused]] void* p_smem,
                                   IAccess iAccess,
                                   Context& context)
    {
        // Map the flat access index to (m, n) shuffle-iteration coordinates.
        constexpr auto idx_start = SFC::get_index(iAccess);
        constexpr auto m_iter    = number<idx_start.at(number<0>{}) / MPerIterShuffle>{};
        constexpr auto n_iter    = number<idx_start.at(number<1>{}) / NPerIterShuffle>{};

        // Per-warp Y-dimension lengths of the accumulator distribution; the
        // zeros sequence anchors the slice origin in those Y dimensions.
        constexpr auto warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // Slice origin: (m_iter, n_iter) scaled by XDL tiles per wave per shuffle,
        // padded with zeros for the warp's Y dims; extent: the per-shuffle XDL tile
        // counts followed by the full warp Y lengths.
        context.working_tile.get_thread_buffer() = acc_tile.get_y_sliced_thread_data(
            merge_sequences(
                sequence<m_iter * NumMXdlPerWavePerShuffle, n_iter * NumNXdlPerWavePerShuffle>{},
                warp_y_index_zeros),
            merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
                            warp_y_lengths));
    }
};
|
||||
|
||||
/// @brief Scale working tile using tensor windows (CShuffle-specific)
|
||||
///
|
||||
/// @par Purpose
|
||||
/// Scales the working tile using row and column scale tensors.
|
||||
/// CShuffle-specific because it creates scale windows from the
|
||||
/// working tile's distribution and handles window movement.
|
||||
///
|
||||
/// @tparam SFC Space filling curve type
|
||||
template <typename SFC>
struct CShuffleScaleWindowOp
{
    // Loads row/column scale tiles through windows built on the working tile's
    // distribution and multiplies them into context.working_tile.
    template <typename OutWindow,
              typename AccTile,
              typename AuxWindows,
              typename IAccess,
              typename Context,
              typename ScaleRowTensor,
              typename ScaleColTensor>
    CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
                                   [[maybe_unused]] const AccTile& acc_tile,
                                   [[maybe_unused]] const AuxWindows& aux_windows,
                                   [[maybe_unused]] void* p_smem,
                                   IAccess iAccess,
                                   Context& context,
                                   const ScaleRowTensor& scale_row_tensor,
                                   const ScaleColTensor& scale_col_tensor)
    {
        // Scale windows share the working tile's distribution so the loaded
        // scale elements line up one-to-one with the working-tile elements.
        auto scale_row_window =
            make_tile_window(scale_row_tensor, context.working_tile.get_tile_distribution());
        auto scale_col_window =
            make_tile_window(scale_col_tensor, context.working_tile.get_tile_distribution());

        const auto scale_row_tile = load_tile(scale_row_window);
        const auto scale_col_tile = load_tile(scale_col_window);

        // working_tile = working_tile * scale_row * scale_col (elementwise).
        tile_elementwise_inout(element_wise::MultiDMultiply{},
                               context.working_tile,
                               context.working_tile,
                               scale_row_tile,
                               scale_col_tile);

        // NOTE(review): scale_row_window/scale_col_window are locals re-created
        // from the tensors on every invocation, so the moves below do not persist
        // to the next access — confirm whether these windows should instead live
        // in the context (as aux_windows do in MoveWindowsOp).
        constexpr index_t num_access = SFC::get_num_of_access();
        if constexpr(iAccess != num_access - 1)
        {
            constexpr auto step = SFC::get_forward_step(number<iAccess>{});
            move_tile_window(scale_row_window, {step.at(number<0>{}), step.at(number<1>{})});
            move_tile_window(scale_col_window, {step.at(number<0>{}), step.at(number<1>{})});
        }
    }
};
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// CShuffle problem and base operation definitions
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
/// @brief Problem configuration for CShuffle epilogue chainer operations
|
||||
/// @note Mirrors CShuffleEpilogueProblem but uses AsDataType/BsDataType for tuple support
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename DsLayout_,
|
||||
typename ELayout_,
|
||||
typename CDElementwise_,
|
||||
index_t kM_,
|
||||
index_t kN_,
|
||||
index_t MWave_,
|
||||
index_t NWave_,
|
||||
index_t MPerXdl_,
|
||||
index_t NPerXdl_,
|
||||
index_t KPerXdl_,
|
||||
bool isCTransposed_,
|
||||
memory_operation_enum MemoryOperation_,
|
||||
index_t kNumWaveGroups_ = 1,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1,
|
||||
bool TiledMMAPermuteN_ = false,
|
||||
index_t BlockedXDLN_PerWarp_ = 1>
|
||||
struct CShuffleEpilogueChainProblem
|
||||
{
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using ELayout = remove_cvref_t<ELayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t MWave = MWave_;
|
||||
static constexpr index_t NWave = NWave_;
|
||||
static constexpr index_t MPerXdl = MPerXdl_;
|
||||
static constexpr index_t NPerXdl = NPerXdl_;
|
||||
static constexpr index_t KPerXdl = KPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
};
|
||||
|
||||
/// @brief Base operation for the CShuffle epilogue chainer: derives all compile-time
///        shuffle parameters from the Problem and builds the shared CShuffleContext.
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogueChainBaseOp
{
    using Problem     = remove_cvref_t<Problem_>;
    using AsDataType  = remove_cvref_t<typename Problem::AsDataType>;
    using BsDataType  = remove_cvref_t<typename Problem::BsDataType>;
    using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
    using ODataType   = remove_cvref_t<typename Problem::ODataType>;
    using DsDataType  = remove_cvref_t<typename Problem::DsDataType>;
    using DsLayout    = remove_cvref_t<typename Problem::DsLayout>;

    // AsDataType/BsDataType may be a single type or a tuple of types.
    static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
    static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;

    // Normalize to tuples so the first element can be extracted uniformly.
    using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
                                               remove_cvref_t<AsDataType>,
                                               remove_cvref_t<tuple<AsDataType>>>;

    using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
                                               remove_cvref_t<BsDataType>,
                                               remove_cvref_t<tuple<BsDataType>>>;

    using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
    using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;

    // If A is packed int4 it is dequantized to B's type before the warp GEMM.
    using ATypeToUse =
        std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
    // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
    using BTypeToUse =
        std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
    using ELayout       = remove_cvref_t<typename Problem::ELayout>;
    using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
    static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
    static constexpr index_t kBlockSize = Problem::kBlockSize;
    static constexpr index_t kMPerBlock = Problem::kMPerBlock;
    static constexpr index_t kNPerBlock = Problem::kNPerBlock;
    static constexpr index_t MWave      = Problem::MWave;
    static constexpr index_t NWave      = Problem::NWave;
    static constexpr index_t MPerXdl    = Problem::MPerXdl;
    static constexpr index_t NPerXdl    = Problem::NPerXdl;
    static constexpr index_t KPerXdl    = Problem::KPerXdl;
    // NOTE(review): semantically a bool flag, stored here as index_t.
    static constexpr index_t isCTransposed = Problem::isCTransposed;
    static constexpr bool FixedVectorSize  = Problem::FixedVectorSize;
    static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
    static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;
    // Elements covered per iteration by all waves in each dimension.
    static constexpr index_t MPerIteration = MPerXdl * MWave;
    static constexpr index_t NPerIteration = NPerXdl * NWave;
    static constexpr index_t NumDTensor    = Problem::NumDTensor;

    static_assert(NumDTensor == DsLayout::size(),
                  "The size of DsDataType and DsLayout should be the same");

    /**
     * @brief Get the vector store size for C tensor.
     *
     * @note The vector store size for output C tensor would depend on multiple factors
     *       like its data layout and warp gemm C transposition. In general it would
     *       be the number of consecutive elements in contiguous C dimension hold by
     *       single thread.
     *
     * @return The vector store size for C tensor.
     */
    CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
    {
        if constexpr(FixedVectorSize)
        {
            return VectorSizeC;
        }
        // Cap at 16 bytes per vectorized access.
        constexpr index_t max_vector_size = 16;
        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            return std::min(static_cast<int>(NPerIteration),
                            static_cast<int>(max_vector_size / sizeof(ODataType)));
        }
        else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
        {
            return std::min(static_cast<int>(MPerIteration),
                            static_cast<int>(max_vector_size / sizeof(ODataType)));
        }
        else
        {
            static_assert(false, "Unsupported ELayout!");
        }
    }

    /**
     * @brief Get the vector store size for Di tensor.
     *
     * @return The vector store size for Di tensor.
     */
    template <index_t I>
    CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
    {
        // Cap at 16 bytes per vectorized access, like GetVectorSizeC().
        constexpr index_t max_vector_size = 16;
        using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
        using DiLayout   = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
        if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
        {
            return std::min(static_cast<int>(NPerIteration),
                            static_cast<int>(max_vector_size / sizeof(DiDataType)));
        }
        else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
        {
            return std::min(static_cast<int>(MPerIteration),
                            static_cast<int>(max_vector_size / sizeof(DiDataType)));
        }
        else
        {
            static_assert(false, "Unsupported DLayout!");
        }
    }
    /**
     * @brief Shuffle tile configuration parameters
     *
     * @details These parameters control the number of XDL tiles processed per wave in each shuffle
     * iteration:
     * - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
     * - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
     */
    static constexpr auto shuffle_tile_tuple = [] {
        constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
        // One XDL tile per shuffle already saturates the store vector width.
        if constexpr(elem_per_thread >= GetVectorSizeC())
        {
            return std::make_tuple(1, 1);
        }
        else
        {
            // Combine enough XDL tiles per shuffle to fill a full store vector.
            constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            {
                static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
                                  (kMPerBlock % num_xdl_shuffles == 0),
                              "kMPerBlock must be divisible by MPerXdl*MWave and "
                              "num_xdl_shuffles for CShuffleEpilogueStageBase");
                return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
            }
            else
            {
                static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
                                  (kNPerBlock % num_xdl_shuffles == 0),
                              "kNPerBlock must be divisible by NPerXdl*NWave and "
                              "num_xdl_shuffles for CShuffleEpilogueStageBase");
                return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
            }
        }
    }();
    static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
    static constexpr index_t NumNXdlPerWavePerShuffle =
        max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));

    // Per-shuffle iteration extents; falls back to a single XDL tile per wave if the
    // combined extents do not divide the block evenly.
    static constexpr auto MNPerIterationShuffle = [] {
        constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
        constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
        if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
            return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
        else
            return std::make_tuple(m_val, n_val);
    }();
    static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
    static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);

    // Warp GEMM selection; also provides the C warp distribution used for slicing.
    using WG = WarpGemmDispatcher<ATypeToUse,
                                  BTypeToUse,
                                  AccDataType,
                                  MPerXdl,
                                  NPerXdl,
                                  KPerXdl,
                                  isCTransposed>;

    using CWarpDstr         = typename WG::CWarpDstr;
    using CWarpTensor       = typename WG::CWarpTensor;
    using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
    // Traversal order of the block's (M, N) space in per-shuffle steps.
    using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
                                    sequence<0, 1>,
                                    sequence<MPerIterationShuffle, NPerIterationShuffle>>;

    // Describes one shuffle iteration's tile in LDS, strided to match ELayout.
    // NOTE(review): the template parameter shadows the class-level Problem alias.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
    {
        // N is contiguous dimension
        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            return make_naive_tensor_descriptor(
                make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
                make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
        }
        // M is contiguous dimension
        else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
        {
            return make_naive_tensor_descriptor(
                make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
                make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
        }
        else
        {
            static_assert(false, "Unsupported ELayout!");
        }
    }

    // Builds the tile-distribution encoding used by the working tile / LDS writes.
    CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
    {
        constexpr auto block_outer_dstr_encoding = [] {
            if constexpr(BlockedXDLN_PerWarp == 1)
            {
                return tile_distribution_encoding<sequence<>,
                                                  tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
                                                        sequence<NumNXdlPerWavePerShuffle, NWave>>,
                                                  tuple<sequence<1, 2>>,
                                                  tuple<sequence<1, 1>>,
                                                  sequence<1, 2>,
                                                  sequence<0, 0>>{};
            }
            else
            {
                // Split the N-direction XDL tiles into raked and blocked factors.
                constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
                // BlockedLayout
                return tile_distribution_encoding<
                    sequence<>,
                    tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
                          sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
                    tuple<sequence<1, 2>>,
                    tuple<sequence<1, 1>>,
                    sequence<1, 2, 2>,
                    sequence<0, 0, 2>>{};
            }
        }();
        // Embed the per-warp C distribution inside the block-level encoding.
        constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
            block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});

        return block_dstr_encoding;
    }

    // LDS bytes needed: one shuffle-iteration tile of ODataType elements.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
    }

    // Thread-raked 2D pattern used for the LDS-read / global-store side.
    using TileEncodingPattern =
        tile_distribution_encoding_pattern_2d<kBlockSize,
                                              MPerIterationShuffle,
                                              NPerIterationShuffle,
                                              GetVectorSizeC(),
                                              tile_distribution_pattern::thread_raked,
                                              Problem::kNumWaveGroups>;

    /// @brief Context structure for CShuffle epilogue operations
    ///
    /// @par Purpose
    /// The context serves as a shared workspace that maintains intermediate results
    /// and resources across multiple epilogue operations. It eliminates the need for
    /// operations to recreate shared data structures and enables data flow
    /// through the operation graph.
    ///
    /// @par Standardized Interface
    /// Uses standardized member names so common operations can work with this context:
    /// - working_tile: Intermediate tile for shuffle operations
    /// - out_tile: Output tile for final results
    /// - aux_windows: Auxiliary tensor windows (D tensors)
    /// - lds_write_window: Window for writing to LDS
    /// - lds_read_window: Window for reading from LDS
    template <typename WorkingTileType,
              typename LdsBlockType,
              typename LdsWriteWindowType,
              typename LdsReadWindowType,
              typename AuxWindowsType,
              typename OutTileType>
    struct CShuffleContext
    {
        WorkingTileType working_tile;        // Working tile for shuffle operations
        LdsBlockType lds_block;              // LDS block view
        LdsWriteWindowType lds_write_window; // Window for writing to LDS
        LdsReadWindowType lds_read_window;   // Window for reading from LDS
        AuxWindowsType aux_windows;          // Auxiliary tensor windows (D tensors)
        OutTileType out_tile;                // Output tile
    };

    // Creates and returns the shared context used by all chained epilogue ops.
    template <typename OutDramWindow, typename AccTile, typename DsDramWindows>
    CK_TILE_DEVICE auto operator()([[maybe_unused]] OutDramWindow& out_window,
                                   [[maybe_unused]] const AccTile& acc_tile,
                                   const DsDramWindows& ds_windows,
                                   void* p_smem)
    {
        static_assert(
            std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
            "Currently, the CShuffleEpilogueStageBase only supports the Row Major Output layout");

        // Working tile: distributed accumulator-typed tile for shuffle slices.
        constexpr auto working_tile_distr =
            make_static_tile_distribution(MakeLdsDistributionEncode());
        auto working_tile = make_static_distributed_tensor<AccDataType>(working_tile_distr);

        // View the provided shared memory as one shuffle-iteration tile of ODataType.
        constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
        auto lds_block = make_tensor_view<address_space_enum::lds>(static_cast<ODataType*>(p_smem),
                                                                   lds_block_desc);

        // Write window uses the working tile's distribution.
        auto lds_write_window = make_tile_window(
            lds_block,
            make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
            {0, 0},
            working_tile_distr);

        // Read window is left distribution-less; ops attach one when loading.
        auto lds_read_window = make_tile_window(
            lds_block,
            make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
            {0, 0});

        constexpr auto dram_tile_distribution =
            TileEncodingPattern::make_2d_static_tile_distribution();
        // One window per D tensor, all sharing the DRAM-side distribution.
        auto aux_windows = generate_tuple(
            [&](auto idx) { return make_tile_window(ds_windows[idx], dram_tile_distribution); },
            number<NumDTensor>{});

        // NOTE(review): presumably this load only materializes out_tile's
        // type/distribution — LDS has not been written yet at this point; confirm.
        auto out_tile = load_tile(make_tile_window(lds_read_window, dram_tile_distribution));

        using ContextType = CShuffleContext<decltype(working_tile),
                                            decltype(lds_block),
                                            decltype(lds_write_window),
                                            decltype(lds_read_window),
                                            decltype(aux_windows),
                                            decltype(out_tile)>;

        ContextType context;
        context.working_tile     = working_tile;
        context.lds_block        = lds_block;
        context.lds_write_window = lds_write_window;
        context.lds_read_window  = lds_read_window;
        context.aux_windows      = aux_windows;
        context.out_tile         = out_tile;

        return context;
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,129 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp"
|
||||
#include "ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp"
|
||||
#include "ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Schedule type tags for epilogue selection
/// @par Purpose
/// Each tag corresponds to a pre-built schedule, these are used to select a schedule
/// at compile time via CshuffleEpilogueSchedule's ScheduleTag template parameter.

/// Standard epilogue schedule: Slice → CastStore → Load → ApplyD → Store → Move
/// make_schedule takes no extra arguments for this tag.
struct DefaultScheduleTag
{
};
|
||||
|
||||
/// RowCol quantization schedule: Slice → ScaleWindow → CastStore → Load → ApplyD → Store → Move
/// make_schedule requires exactly 2 scale tensor (window) arguments for this tag.
struct RowColQuantScheduleTag
{
};
|
||||
|
||||
/// Tensor quantization schedule: Slice → ScaleScalar → CastStore → Load → ApplyD → Store → Move
/// make_schedule requires exactly 2 scalar scale arguments for this tag.
struct TensorQuantScheduleTag
{
};
|
||||
|
||||
/// @brief CShuffle epilogue scheduler providing pre-built schedules
|
||||
///
|
||||
/// @par Overview
|
||||
/// CshuffleEpilogueSchedule acts as the scheduler component for EpilogueChainer.
|
||||
/// It provides context creation and pre-built schedules. The scheduler
|
||||
/// uses tags to select/create appropriate epilogue schedule.
|
||||
///
|
||||
/// @tparam Problem The epilogue problem configuration
|
||||
/// @tparam ScheduleTag Tag selecting the epilogue schedule
|
||||
template <typename Problem, typename ScheduleTag = DefaultScheduleTag>
|
||||
struct CshuffleEpilogueSchedule
|
||||
{
|
||||
using ProblemType = Problem;
|
||||
using BaseOp = CShuffleEpilogueChainBaseOp<Problem>;
|
||||
|
||||
static constexpr index_t NumAccess = BaseOp::SFC::get_num_of_access();
|
||||
|
||||
/// @brief Create context for epilogue operations
|
||||
template <typename OutWindow, typename AccTile, typename AuxWindows>
|
||||
CK_TILE_DEVICE static auto create_context(OutWindow& out_window,
|
||||
const AccTile& acc_tile,
|
||||
const AuxWindows& aux_windows,
|
||||
void* p_smem)
|
||||
{
|
||||
return BaseOp{}(out_window, acc_tile, aux_windows, p_smem);
|
||||
}
|
||||
|
||||
/// @brief Make schedule based on compile-time tag selection
|
||||
template <typename... Args>
|
||||
CK_TILE_DEVICE static auto make_schedule(Args&&... args)
|
||||
{
|
||||
if constexpr(std::is_same_v<ScheduleTag, DefaultScheduleTag>)
|
||||
{
|
||||
// Standard epilogue
|
||||
// Schedule: Slice -> CastAndStoreLds -> Load -> ApplyD -> Store -> MoveWindows
|
||||
static_assert(sizeof...(args) == 0, "DefaultSchedule expects no arguments");
|
||||
return make_graph<NumAccess>(
|
||||
make_node<CShuffleSliceOp<typename BaseOp::SFC,
|
||||
typename BaseOp::CWarpDstr,
|
||||
BaseOp::NumMXdlPerWavePerShuffle,
|
||||
BaseOp::NumNXdlPerWavePerShuffle,
|
||||
BaseOp::MPerIterationShuffle,
|
||||
BaseOp::NPerIterationShuffle>>(),
|
||||
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
|
||||
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
|
||||
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
|
||||
make_node<StoreOp<Problem::MemoryOperation>>(),
|
||||
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
|
||||
}
|
||||
else if constexpr(std::is_same_v<ScheduleTag, RowColQuantScheduleTag>)
|
||||
{
|
||||
// RowCol quantization schedule with tensor windows
|
||||
// Schedule: Slice -> ScaleWindow -> CastAndStoreLds -> Load -> ApplyD -> Store ->
|
||||
// MoveWindows
|
||||
static_assert(sizeof...(args) == 2,
|
||||
"RowColQuantSchedule requires exactly 2 scale tensor arguments");
|
||||
return make_graph<NumAccess>(
|
||||
make_node<CShuffleSliceOp<typename BaseOp::SFC,
|
||||
typename BaseOp::CWarpDstr,
|
||||
BaseOp::NumMXdlPerWavePerShuffle,
|
||||
BaseOp::NumNXdlPerWavePerShuffle,
|
||||
BaseOp::MPerIterationShuffle,
|
||||
BaseOp::NPerIterationShuffle>>(),
|
||||
make_node<CShuffleScaleWindowOp<typename BaseOp::SFC>>(std::forward<Args>(args)...),
|
||||
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
|
||||
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
|
||||
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
|
||||
make_node<StoreOp<Problem::MemoryOperation>>(),
|
||||
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
|
||||
}
|
||||
else if constexpr(std::is_same_v<ScheduleTag, TensorQuantScheduleTag>)
|
||||
{
|
||||
// Tensor quantization schedule with scalar values
|
||||
// Schedule: Slice -> ScaleScalar -> CastAndStoreLds -> Load -> ApplyD -> Store ->
|
||||
// MoveWindows
|
||||
static_assert(sizeof...(args) == 2,
|
||||
"TensorQuantSchedule requires exactly 2 scalar arguments");
|
||||
return make_graph<NumAccess>(
|
||||
make_node<CShuffleSliceOp<typename BaseOp::SFC,
|
||||
typename BaseOp::CWarpDstr,
|
||||
BaseOp::NumMXdlPerWavePerShuffle,
|
||||
BaseOp::NumNXdlPerWavePerShuffle,
|
||||
BaseOp::MPerIterationShuffle,
|
||||
BaseOp::NPerIterationShuffle>>(),
|
||||
make_node<ScaleScalarOp>(std::forward<Args>(args)...),
|
||||
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
|
||||
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
|
||||
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
|
||||
make_node<StoreOp<Problem::MemoryOperation>>(),
|
||||
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown schedule tag");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
213
include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp
Normal file
213
include/ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp
Normal file
@@ -0,0 +1,213 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Epilogue Chainer - Modular epilogue processing facilitator
///
/// @par Overview
/// EpilogueChainer provides an interface for processing epilogue operations
/// through schedules. The chainer uses decomposed epilogue operations, these are
/// scheduled/sequenced by a Scheduler to form operation graphs. All queries
/// (SMEM size, vector sizes, LDS distribution) are delegated to the scheduler's
/// base operation so existing epilogue callers keep working unchanged.
///
/// @tparam Scheduler The schedule provider that defines epilogue operation graphs
template <typename Scheduler>
struct EpilogueChainer
{
    using Problem = typename Scheduler::ProblemType;
    using BaseOp  = typename Scheduler::BaseOp;

    // Re-exported base-op types so the chainer is a drop-in epilogue replacement.
    using ODataType   = typename BaseOp::ODataType;
    using DsDataType  = typename BaseOp::DsDataType;
    using DsLayout    = typename BaseOp::DsLayout;
    using AccDataType = typename BaseOp::AccDataType;
    static constexpr auto MemoryOperation = BaseOp::MemoryOperation;

    /// @brief LDS bytes required by the underlying base operation.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return BaseOp::GetSmemSize(); }

    /// @brief Vector store size for the C/E output tensor (delegated).
    CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
    {
        return BaseOp::GetVectorSizeC();
    }

    /// @brief Vector store size for the I-th D tensor (delegated).
    template <index_t I>
    CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> idx)
    {
        return BaseOp::GetVectorSizeD(idx);
    }

    /// @brief LDS tile distribution encoding (delegated).
    CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
    {
        return BaseOp::MakeLdsDistributionEncode();
    }

    /// @brief Process epilogue through scheduler-defined operation graph
    ///
    /// @par Flow
    /// 1. Create shared context through scheduler
    /// 2. Generate operation schedule based on arguments
    /// 3. Run scheduled operations in sequence
    ///
    /// @param out_window  DRAM output window written by the schedule.
    /// @param acc_tile    Accumulator tile produced by the GEMM mainloop.
    /// @param aux_windows Windows over the auxiliary D tensors.
    /// @param p_smem      Raw LDS pointer (at least GetSmemSize() bytes).
    /// @param args        Extra schedule-specific arguments (e.g. scale factors),
    ///                    forwarded unchanged to Scheduler::make_schedule.
    template <typename OutWindow, typename AccTile, typename AuxWindows, typename... Args>
    CK_TILE_DEVICE void operator()(OutWindow& out_window,
                                   const AccTile& acc_tile,
                                   const AuxWindows& aux_windows,
                                   void* p_smem,
                                   Args&&... args) const
    {
        // The context serves as a shared workspace that maintains intermediate results
        // and resources across multiple epilogue operations.
        auto context  = Scheduler::create_context(out_window, acc_tile, aux_windows, p_smem);
        auto schedule = Scheduler::make_schedule(std::forward<Args>(args)...);
        schedule(out_window, acc_tile, aux_windows, p_smem, context);
    }
};
|
||||
|
||||
/// @brief Epilogue operation wrapper with arguments
///
/// @par Purpose
/// EpilogueNode wraps individual epilogue operations with their required arguments,
/// allowing them to be composed into operation graphs. Arguments are captured at
/// construction time (by value) and automatically forwarded during processing.
///
/// @tparam EpilogueType Epilogue operation (e.g., SliceEpilogue, ScaleEpilogue)
/// @tparam Args Types of arguments required by the epilogue operation
template <typename EpilogueType, typename... Args>
struct EpilogueNode
{
    using Epilogue = EpilogueType;
    // Captured by value; replayed on every invocation of the node.
    ck_tile::tuple<Args...> args;

    constexpr EpilogueNode(Args... a) : args(a...) {}

    /// @brief Process epilogue without iteration index
    ///
    /// Invokes a fresh EpilogueType instance, appending the captured arguments
    /// after the shared context.
    template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
    CK_TILE_DEVICE void operator()(OutWindow& out_window,
                                   const AccTile& acc_tile,
                                   const AuxWindows& aux_windows,
                                   void* p_smem,
                                   Context& context) const
    {
        ck_tile::apply(
            [&](auto&&... epilogue_args) {
                EpilogueType{}(out_window,
                               acc_tile,
                               aux_windows,
                               p_smem,
                               context,
                               std::forward<decltype(epilogue_args)>(epilogue_args)...);
            },
            args);
    }

    /// @brief Process epilogue with iteration index
    ///
    /// Same as the overload above, but additionally passes the compile-time
    /// access index (before the context) for operations that vary per access.
    template <typename OutWindow,
              typename AccTile,
              typename AuxWindows,
              typename Context,
              index_t I>
    CK_TILE_DEVICE void operator()(OutWindow& out_window,
                                   const AccTile& acc_tile,
                                   const AuxWindows& aux_windows,
                                   void* p_smem,
                                   Context& context,
                                   number<I> iAccess) const
    {
        ck_tile::apply(
            [&](auto&&... epilogue_args) {
                EpilogueType{}(out_window,
                               acc_tile,
                               aux_windows,
                               p_smem,
                               iAccess,
                               context,
                               std::forward<decltype(epilogue_args)>(epilogue_args)...);
            },
            args);
    }
};
|
||||
|
||||
/// @brief Specialization for epilogue operation wrapper with no arguments
///
/// Avoids the ck_tile::apply indirection of the primary template: the wrapped
/// operation is invoked directly with only the shared chain parameters.
template <typename EpilogueType>
struct EpilogueNode<EpilogueType>
{
    using Epilogue = EpilogueType;
    ck_tile::tuple<> args; // kept for structural parity with the primary template

    constexpr EpilogueNode() = default;

    /// @brief Process epilogue without iteration index
    template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
    CK_TILE_DEVICE void operator()(OutWindow& out_window,
                                   const AccTile& acc_tile,
                                   const AuxWindows& aux_windows,
                                   void* p_smem,
                                   Context& context) const
    {
        EpilogueType{}(out_window, acc_tile, aux_windows, p_smem, context);
    }

    /// @brief Process epilogue with compile-time iteration index
    template <typename OutWindow,
              typename AccTile,
              typename AuxWindows,
              typename Context,
              index_t I>
    CK_TILE_DEVICE void operator()(OutWindow& out_window,
                                   const AccTile& acc_tile,
                                   const AuxWindows& aux_windows,
                                   void* p_smem,
                                   Context& context,
                                   number<I> iAccess) const
    {
        EpilogueType{}(out_window, acc_tile, aux_windows, p_smem, iAccess, context);
    }
};
|
||||
|
||||
/// @brief Operation graph that sequentially composes multiple epilogue operations
///
/// @tparam Steps Number of iterations (one per shuffle access)
/// @tparam EpilogueTypes Types of epilogue nodes in the operation graph
template <index_t Steps, typename... EpilogueTypes>
struct EpilogueGraph
{
    // Nodes are stored by value and executed in declaration order.
    ck_tile::tuple<EpilogueTypes...> epilogues;

    constexpr EpilogueGraph() = default;
    constexpr EpilogueGraph(EpilogueTypes... eps) : epilogues(eps...) {}

    /// @brief Process all epilogues for each iteration in sequence
    ///
    /// Both loops are compile-time unrolled via static_for, so every node
    /// receives the access index as a compile-time number<>.
    template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
    CK_TILE_DEVICE void operator()(OutWindow& out_window,
                                   const AccTile& acc_tile,
                                   const AuxWindows& aux_windows,
                                   void* p_smem,
                                   Context& context) const
    {
        // For each iteration, process all epilogues in order
        static_for<0, Steps, 1>{}([&](auto iAccess) {
            static_for<0, sizeof...(EpilogueTypes), 1>{}([&](auto I) {
                epilogues.template get<I.value>()(
                    out_window, acc_tile, aux_windows, p_smem, context, iAccess);
            });
        });
    }
};
|
||||
|
||||
/// @brief Helper function for creating epilogue nodes
///
/// @tparam EpilogueType The epilogue operation to wrap
/// @param args Arguments captured by value into the node and replayed on each call
/// @return An EpilogueNode wrapping EpilogueType with the given arguments
template <typename EpilogueType, typename... Args>
constexpr auto make_node(Args... args)
{
    return EpilogueNode<EpilogueType, Args...>{args...};
}
|
||||
|
||||
/// @brief Helper function for creating operation graphs
///
/// @tparam Steps Number of iterations the graph runs (one per access)
/// @param epilogues Nodes to execute, in order, on every iteration
/// @return An EpilogueGraph composing the given nodes
template <index_t Steps, typename... EpilogueTypes>
constexpr auto make_graph(EpilogueTypes... epilogues)
{
    return EpilogueGraph<Steps, EpilogueTypes...>{epilogues...};
}
|
||||
|
||||
} // namespace ck_tile
|
||||
879
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Normal file
879
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Normal file
@@ -0,0 +1,879 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Compile-time problem description consumed by CShuffleEpilogue.
///
/// Bundles the data types, layouts, tile sizes and feature switches of the
/// epilogue into one type; CShuffleEpilogue re-exports these as constants.
template <typename AsDataType_,
          typename BsDataType_,
          typename DsDataType_,
          typename AccDataType_,
          typename ODataType_,
          typename DsLayout_,
          typename ELayout_,
          typename CDElementwise_,
          index_t kM_,
          index_t kN_,
          index_t MWave_,
          index_t NWave_,
          index_t MPerXdl_,
          index_t NPerXdl_,
          index_t KPerXdl_,
          bool isCTransposed_,
          index_t kNumWaveGroups_ = 1,
          bool FixedVectorSize_   = false,
          index_t VectorSizeC_    = 1,
          bool TiledMMAPermuteN_  = false,
          index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
          bool DoubleSmemBuffer_       = false>
struct CShuffleEpilogueProblem
{
    using AsDataType    = remove_cvref_t<AsDataType_>;
    using BsDataType    = remove_cvref_t<BsDataType_>;
    using AccDataType   = remove_cvref_t<AccDataType_>;
    using ODataType     = remove_cvref_t<ODataType_>;
    using DsDataType    = remove_cvref_t<DsDataType_>;
    using DsLayout      = remove_cvref_t<DsLayout_>;
    using ELayout       = remove_cvref_t<ELayout_>;
    using CDElementwise = remove_cvref_t<CDElementwise_>;
    // One thread block spans MWave x NWave waves.
    static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
    static constexpr index_t kMPerBlock = kM_;
    static constexpr index_t kNPerBlock = kN_;
    static constexpr index_t MWave      = MWave_;
    static constexpr index_t NWave      = NWave_;
    static constexpr index_t MPerXdl    = MPerXdl_;
    static constexpr index_t NPerXdl    = NPerXdl_;
    static constexpr index_t KPerXdl    = KPerXdl_;
    static constexpr index_t isCTransposed      = isCTransposed_;
    static constexpr bool FixedVectorSize       = FixedVectorSize_;
    static constexpr index_t VectorSizeC        = VectorSizeC_;
    static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
    static constexpr bool DoubleSmemBuffer      = DoubleSmemBuffer_;
    static constexpr bool TiledMMAPermuteN      = TiledMMAPermuteN_;
    static constexpr index_t kNumWaveGroups     = kNumWaveGroups_;
    static constexpr index_t NumDTensor         = DsDataType::size();

    static_assert(NumDTensor == DsLayout::size(),
                  "The size of DsDataType and DsLayout should be the same");
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct CShuffleEpilogue
|
||||
{
|
||||
    using Problem     = remove_cvref_t<Problem_>;
    using AsDataType  = remove_cvref_t<typename Problem::AsDataType>;
    using BsDataType  = remove_cvref_t<typename Problem::BsDataType>;
    using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
    using ODataType   = remove_cvref_t<typename Problem::ODataType>;
    using DsDataType  = remove_cvref_t<typename Problem::DsDataType>;
    using DsLayout    = remove_cvref_t<typename Problem::DsLayout>;

    // A/B may be given either as a single type or as a tuple of types;
    // normalize both to tuples and take the first element below.
    static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
    static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;

    using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
                                               remove_cvref_t<AsDataType>,
                                               remove_cvref_t<tuple<AsDataType>>>;

    using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
                                               remove_cvref_t<BsDataType>,
                                               remove_cvref_t<tuple<BsDataType>>>;

    using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
    using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;

    // Packed 4-bit A operands are dequantized to B's type before the warp gemm.
    using ATypeToUse = std::conditional_t<std::is_same_v<ADataType, pk_int4_t> ||
                                              std::is_same_v<ADataType, pk_fp4_t>,
                                          BDataType,
                                          ADataType>;
    // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
    using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
                                              std::is_same_v<BDataType, pk_fp4_t> ||
                                              sizeof(BDataType) < sizeof(ADataType),
                                          ADataType,
                                          BDataType>;

    using ELayout       = remove_cvref_t<typename Problem::ELayout>;
    using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
    // Tile geometry, re-exported from the Problem for convenience.
    static constexpr index_t kBlockSize    = Problem::kBlockSize;
    static constexpr index_t kMPerBlock    = Problem::kMPerBlock;
    static constexpr index_t kNPerBlock    = Problem::kNPerBlock;
    static constexpr index_t MWave         = Problem::MWave;
    static constexpr index_t NWave         = Problem::NWave;
    static constexpr index_t MPerXdl       = Problem::MPerXdl;
    static constexpr index_t NPerXdl       = Problem::NPerXdl;
    static constexpr index_t KPerXdl       = Problem::KPerXdl;
    static constexpr index_t isCTransposed = Problem::isCTransposed;
    static constexpr bool FixedVectorSize  = Problem::FixedVectorSize;
    static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
    static constexpr bool EightWave        = (MWave * NWave == 8);
    // With 8 waves all N xdl outputs of a warp are contiguous.
    static constexpr index_t BlockedXDLN_PerWarp =
        EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
    static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
    static constexpr index_t VectorSizeC   = Problem::VectorSizeC;
    // Elements covered by one shuffle iteration across all waves.
    static constexpr index_t MPerIteration = MPerXdl * MWave;
    static constexpr index_t NPerIteration = NPerXdl * NWave;
    static constexpr index_t NumDTensor    = Problem::NumDTensor;
    static constexpr index_t MRepeat       = kMPerBlock / (MPerXdl * MWave);
    static constexpr index_t NRepeat       = kNPerBlock / (NPerXdl * NWave);

    // Elementwise functor applied to (acc, D...) before the store.
    CDElementwise elfunc_;

    CK_TILE_DEVICE CShuffleEpilogue(CDElementwise elfunc = CDElementwise{}) : elfunc_(elfunc) {};

    static_assert(NumDTensor == DsLayout::size(),
                  "The size of DsDataType and DsLayout should be the same");
|
||||
|
||||
    /// @brief Human-readable name fragment identifying this epilogue configuration
    ///        (wave grid, xdl tile sizes, C vector size, transposition).
    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "CShuffleEpilogue",
                      concat('x', MWave, NWave),
                      concat('x', MPerXdl, NPerXdl, KPerXdl),
                      VectorSizeC,
                      isCTransposed ? "CTransposed" : "CNotTransposed");
        // clang-format on
    }
|
||||
|
||||
    /**
     * @brief Get the vector store size for C tensor.
     *
     * @note The vector store size for output C tensor would depend on multiple factors
     *       like its data layout and warp gemm C transposition. In general it would
     *       be the number of consecutive elements in contiguous C dimension hold by
     *       single thread.
     *
     * @return The vector store size for C tensor.
     */
    CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
    {
        if constexpr(FixedVectorSize)
        {
            // Caller pinned the vector size explicitly via the Problem.
            return VectorSizeC;
        }
        // 16 bytes = widest per-thread global store.
        constexpr index_t max_vector_size = 16;
        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            // N is the contiguous dimension.
            return std::min(static_cast<int>(NPerIteration),
                            static_cast<int>(max_vector_size / sizeof(ODataType)));
        }
        else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
        {
            // M is the contiguous dimension.
            return std::min(static_cast<int>(MPerIteration),
                            static_cast<int>(max_vector_size / sizeof(ODataType)));
        }
        else
        {
            static_assert(false, "Unsupported ELayout!");
        }
    }
|
||||
|
||||
/**
|
||||
* @brief Get the vector store size for Di tensor.
|
||||
*
|
||||
* @return The vector store size for Di tensor.
|
||||
*/
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
|
||||
{
|
||||
constexpr index_t max_vector_size = 16;
|
||||
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(NPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
}
|
||||
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return std::min(static_cast<int>(MPerIteration),
|
||||
static_cast<int>(max_vector_size / sizeof(DiDataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported DLayout!");
|
||||
}
|
||||
return max_vector_size / sizeof(DiDataType);
|
||||
}
|
||||
|
||||
    /**
     * @brief Shuffle tile configuration parameters check and aligment
     *
     * @details Return tuple(1, 1) if shuffle_tile values are too large for SMEM.
     *          Also falls back to (1, 1) when double SMEM buffering is enabled,
     *          since the buffer is then split in two.
     */
    template <index_t m_shuffle_tile, index_t n_shuffle_tile>
    CK_TILE_HOST_DEVICE static constexpr auto AlignShuffleTileWithSmem()
    {
        // Candidate shuffle-tile extents in elements.
        constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile;
        constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;

        constexpr auto shuffle_tile =
            m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
                ? std::make_tuple(1, 1)
                : std::make_tuple(m_shuffle_tile, n_shuffle_tile);

        return shuffle_tile;
    }
|
||||
|
||||
    /**
     * @brief Shuffle tile configuration parameters
     *
     * @details These parameters control the number of XDL tiles processed per wave in each shuffle
     * iteration:
     * - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
     * - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
     */
    static constexpr auto shuffle_tile_tuple = [] {
        // Accumulator elements held by one thread per xdl tile.
        constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
        if constexpr(elem_per_thread <= GetVectorSizeC())
        {
            // A single xdl tile already fills a full store vector.
            return std::make_tuple(1, 1);
        }
        else
        {
            // Combine several xdl tiles per shuffle so a thread can issue
            // full-width vector stores.
            constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC();
            static_assert(elem_per_thread % GetVectorSizeC() == 0);
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            {
                static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
                                  (kMPerBlock % num_xdl_shuffles == 0),
                              "kMPerBlock must be divisible by MPerXdl*MWave and "
                              "num_xdl_shuffles for CShuffleEpilogue");
                return AlignShuffleTileWithSmem<min(num_xdl_shuffles,
                                                    kMPerBlock / (MPerXdl * MWave)),
                                                1>();
            }
            else
            {
                static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
                                  (kNPerBlock % num_xdl_shuffles == 0),
                              "kNPerBlock must be divisible by NPerXdl*NWave and "
                              "num_xdl_shuffles for CShuffleEpilogue");
                return AlignShuffleTileWithSmem<1,
                                                min(num_xdl_shuffles,
                                                    kNPerBlock / (NPerXdl * NWave))>();
            }
        }
    }();
    static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
    static constexpr index_t NumNXdlPerWavePerShuffle =
        max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));

    // Per-iteration shuffle extents; fall back to a single xdl tile per wave if
    // the combined extents do not evenly divide the block tile.
    static constexpr auto MNPerIterationShuffle = [] {
        constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
        constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
        if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
            return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
        else
            return std::make_tuple(m_val, n_val);
    }();
    static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
    static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);

    // Warp gemm matching the mainloop's accumulator layout.
    using WG = WarpGemmDispatcher<ATypeToUse,
                                  BTypeToUse,
                                  AccDataType,
                                  MPerXdl,
                                  NPerXdl,
                                  KPerXdl,
                                  isCTransposed>;

    using CWarpDstr         = typename WG::CWarpDstr;
    using CWarpTensor       = typename WG::CWarpTensor;
    using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
    // Space-filling curve enumerating the shuffle iterations over the block tile.
    using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
                                    sequence<0, 1>,
                                    sequence<MPerIterationShuffle, NPerIterationShuffle>>;
||||
|
||||
    /// @brief Build the LDS tensor descriptor for one shuffle iteration's staging buffer.
    ///
    /// The descriptor folds the (M, N) iteration tile through an "LDS layer"
    /// split plus (on gfx950) an extra row padding so that vectorized writes
    /// and reads avoid bank conflicts. The two branches mirror each other for
    /// row-major (N contiguous) and column-major (M contiguous) output.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
    {
        constexpr auto DataTypeSize = sizeof(ODataType);
        constexpr index_t VectorLen = GetVectorSizeC();
        constexpr index_t banks     = get_n_lds_banks();

        // Each LDS bank is one 32-bit word wide.
        constexpr index_t BytesPerBank = 4;

        // N is contiguous dimension
        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            // Number of rows interleaved per "LDS layer" so one row of layers
            // spans all banks; at least 1.
            constexpr index_t MLdsLayerRequired =
                banks * BytesPerBank / NPerIterationShuffle / DataTypeSize;
            constexpr auto MLdsLayer = max(1, MLdsLayerRequired);

            constexpr index_t BaseStrideElems = NPerIterationShuffle * MLdsLayer;
            static_assert((BaseStrideElems * DataTypeSize) % BytesPerBank == 0,
                          "LDS row stride must be 4B-aligned for bank-word padding logic");
            // calculate how many elements to pad to avoid bank conflict
#if defined(__gfx950__)
            // Pad one bank word per row when the row stride is an even number
            // of words, so consecutive rows start on different banks.
            constexpr index_t ElemsPer4B = BytesPerBank / ck_tile::gcd(BytesPerBank, DataTypeSize);
            constexpr auto ToWords       = [](index_t elems) constexpr {
                return (elems * DataTypeSize) / BytesPerBank;
            };
            constexpr index_t BaseWords  = ToWords(BaseStrideElems);
            constexpr index_t PadWords   = ((BaseWords % 2) == 0) ? 1 : 0;
            constexpr auto PaddingAmount = PadWords * ElemsPer4B;
#else
            constexpr auto PaddingAmount = 0;
#endif

            // (M/layer, N/vec * layer, vec) with padded row stride.
            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<MPerIterationShuffle / MLdsLayer>{},
                           number<NPerIterationShuffle / VectorLen * MLdsLayer>{},
                           number<VectorLen>{}),
                make_tuple(number<NPerIterationShuffle * MLdsLayer + PaddingAmount>{},
                           number<VectorLen>{},
                           number<1>{}),
                number<VectorLen>{},
                number<1>{});

            // Unmerge the middle dimension back into (layer, N/vec).
            constexpr auto lds_block_desc_1 = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(make_pass_through_transform(number<MPerIterationShuffle / MLdsLayer>{}),
                           make_unmerge_transform(make_tuple(
                               number<MLdsLayer>{}, number<NPerIterationShuffle / VectorLen>{})),
                           make_pass_through_transform(number<VectorLen>{})),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));

            // Merge back to a logical 2D (M, N) view.
            constexpr auto lds_block_desc = transform_tensor_descriptor(
                lds_block_desc_1,
                make_tuple(make_merge_transform_v3_division_mod(make_tuple(
                               number<MPerIterationShuffle / MLdsLayer>{}, number<MLdsLayer>{})),
                           make_merge_transform_v3_division_mod(make_tuple(
                               number<NPerIterationShuffle / VectorLen>{}, number<VectorLen>{}))),
                make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            return lds_block_desc;
        }
        // M is contiguous dimension
        else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
        {
            // Mirror of the row-major case with M and N roles swapped.
            constexpr index_t NLdsLayerRequired =
                get_n_lds_banks() * BytesPerBank / MPerIterationShuffle / DataTypeSize;
            constexpr auto NLdsLayer = max(1, NLdsLayerRequired);

            constexpr index_t BaseStrideElems = MPerIterationShuffle * NLdsLayer;

            static_assert((BaseStrideElems * DataTypeSize) % BytesPerBank == 0,
                          "LDS row stride must be 4B-aligned for bank-word padding logic");

#if defined(__gfx950__)
            constexpr index_t ElemsPer4B = BytesPerBank / ck_tile::gcd(BytesPerBank, DataTypeSize);
            constexpr auto ToWords       = [](index_t elems) constexpr {
                return (elems * DataTypeSize) / BytesPerBank;
            };
            constexpr index_t BaseWords  = ToWords(BaseStrideElems);
            constexpr index_t PadWords   = ((BaseWords % 2) == 0) ? 1 : 0;
            constexpr auto PaddingAmount = PadWords * ElemsPer4B;
#else
            constexpr auto PaddingAmount = 0;
#endif

            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NPerIterationShuffle / NLdsLayer>{},
                           number<MPerIterationShuffle / VectorLen * NLdsLayer>{},
                           number<VectorLen>{}),
                make_tuple(number<MPerIterationShuffle * NLdsLayer + PaddingAmount>{},
                           number<VectorLen>{},
                           number<1>{}),
                number<VectorLen>{},
                number<1>{});

            constexpr auto lds_block_desc_1 = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(make_pass_through_transform(number<NPerIterationShuffle / NLdsLayer>{}),
                           make_unmerge_transform(make_tuple(
                               number<NLdsLayer>{}, number<MPerIterationShuffle / VectorLen>{})),
                           make_pass_through_transform(number<VectorLen>{})),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
                make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));

            constexpr auto lds_block_desc = transform_tensor_descriptor(
                lds_block_desc_1,
                make_tuple(make_merge_transform_v3_division_mod(make_tuple(
                               number<NPerIterationShuffle / NLdsLayer>{}, number<NLdsLayer>{})),
                           make_merge_transform_v3_division_mod(make_tuple(
                               number<MPerIterationShuffle / VectorLen>{}, number<VectorLen>{}))),
                make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
                make_tuple(sequence<0>{}, sequence<1>{}));

            return lds_block_desc;
        }
        else
        {
            static_assert(false, "Unsupported ELayout!");
        }
    }
|
||||
|
||||
/// @brief Build the tile-distribution encoding for the accumulator tile that is staged
///        through LDS during the CShuffle epilogue.
///
/// Combines an "outer" per-wave encoding (chosen below) with the warp-level C
/// distribution (CWarpDstr) via make_embed_tile_distribution_encoding.
///
/// Three outer encodings are produced:
///  - BlockedXDLN_PerWarp == 1: plain (M-xdl x MWave) by (N-xdl x NWave) layout.
///  - Otherwise, the N dimension is split into raked and blocked XDL groups; the exact
///    P/Y mappings differ depending on gfx950 / packed-4-bit A/B inputs and EightWave.
///    NOTE(review): the meaning of each sequence position follows the
///    tile_distribution_encoding convention (R, Hs, Ps-major, Ps-minor, Ys-major,
///    Ys-minor); the specific index choices here are taken as-is from the original.
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
    constexpr auto block_outer_dstr_encoding = [] {
        if constexpr(BlockedXDLN_PerWarp == 1)
        {
            // Simple case: no blocked-N split; waves tile M and N directly.
            return tile_distribution_encoding<sequence<>,
                                              tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
                                                    sequence<NumNXdlPerWavePerShuffle, NWave>>,
                                              tuple<sequence<1, 2>>,
                                              tuple<sequence<1, 1>>,
                                              sequence<1, 2>,
                                              sequence<0, 0>>{};
        }
        else
        {
#if defined(__gfx950__)
            constexpr auto is_950 = true;
#else
            constexpr auto is_950 = false;
#endif
            // Split per-warp N-XDLs into a raked part and a blocked part.
            constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
            // BlockedLayout
            // this branch is for original a16w4
            if constexpr(is_950 || is_any_of<ADataType, pk_int4_t, pk_fp4_t>::value ||
                         is_any_of<BDataType, pk_int4_t, pk_fp4_t>::value)
            {
                if constexpr(EightWave)
                {
                    // Eight-wave variant: P-dimension order is (N-dim, M-dim).
                    return tile_distribution_encoding<
                        sequence<>,
                        tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
                              sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
                        tuple<sequence<2, 1>>,
                        tuple<sequence<1, 1>>,
                        sequence<1, 2, 2>,
                        sequence<0, 0, 2>>{};
                }
                else
                {
                    // Same Hs layout, but P-dimension order is (M-dim, N-dim).
                    return tile_distribution_encoding<
                        sequence<>,
                        tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
                              sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
                        tuple<sequence<1, 2>>,
                        tuple<sequence<1, 1>>,
                        sequence<1, 2, 2>,
                        sequence<0, 0, 2>>{};
                }
            }
            else
            {
                // Generic blocked layout: the blocked-N group sits between the raked
                // group and NWave, and NWave maps to the P dimension instead.
                return tile_distribution_encoding<
                    sequence<>,
                    tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
                          sequence<RakedXDLN_PerWarp, BlockedXDLN_PerWarp, NWave>>,
                    tuple<sequence<1, 2>>,
                    tuple<sequence<1, 2>>,
                    sequence<1, 2, 2>,
                    sequence<0, 0, 1>>{};
            }
        }
    }();
    // Embed the warp-level C distribution inside the per-wave outer encoding.
    constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
        block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});

    return block_dstr_encoding;
}
|
||||
|
||||
/// @brief LDS bytes required by the shuffle epilogue: the element-space size of the
///        LDS block descriptor times the width of the stored output element type.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    return MakeLdsBlockDescriptor<Problem>().get_element_space_size() * sizeof(ODataType);
}
|
||||
|
||||
/// @brief Apply optional row (M) and column (N) scaling to the accumulator tile that is
///        about to be written through LDS.
///
/// Dispatches at compile time on the scale types:
///  - Both EmptyScale: no-op.
///  - Both scalar AccDataType: uniform multiply of every element.
///  - Otherwise: treated as tile windows; scale tiles are loaded, multiplied in
///    element-wise, and the scale windows are advanced to the next SFC access position
///    (except after the final access).
///
/// @tparam iAccess index of the current space-filling-curve access; used to compute the
///                 forward step for the scale windows.
/// @param lds_tile       accumulator tile, scaled in place
/// @param scale_m_window scalar, EmptyScale, or tile window of M-scales
/// @param scale_n_window scalar, EmptyScale, or tile window of N-scales
template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
CK_TILE_DEVICE void
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
{
    // Check if scales are EmptyScale first (no scaling needed)
    if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
    {
        // No scaling needed - this is a no-op
    }
    // Check if scales are scalar AccDataType
    else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
                      std::is_same_v<ScaleN, AccDataType>)
    {
        // Handle scalar scales: every element gets the same combined factor.
        const AccDataType scale_m = scale_m_window;
        const AccDataType scale_n = scale_n_window;
        tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
                               lds_tile);
    }
    // Otherwise, assume they are tile windows that can be loaded
    else
    {
        // Load per-element scale tiles with the same distribution as lds_tile.
        const auto scale_m_tile = load_tile(scale_m_window);
        const auto scale_n_tile = load_tile(scale_n_window);

        // Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
        tile_elementwise_inout(
            element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);

        // Move scale windows to the next access position; skip on the last access
        // (there is no further step defined past the end of the curve).
        constexpr index_t num_access = SFC::get_num_of_access();
        if constexpr(iAccess != num_access - 1)
        {
            constexpr auto step = SFC::get_forward_step(number<iAccess>{});

            move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
            move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
        }
    }
}
|
||||
|
||||
/// @brief Copy the slice of the block accumulator corresponding to SFC access iAccess
///        into the per-iteration LDS tile.
///
/// The SFC index is converted to (mIter, nIter) shuffle-iteration coordinates, which are
/// then scaled by the number of XDL tiles handled per shuffle to form the Y-space start
/// offset. The warp-level Y lengths are appended so the full per-warp payload is sliced.
template <index_t iAccess, typename OAccTile, typename LdsTile>
CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
{
    // Top-left (M, N) element coordinate of this access on the space-filling curve.
    constexpr auto idx_y_start = SFC::get_index(number<iAccess>{});

    // Which shuffle iteration this access belongs to in each dimension.
    constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
    constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
    constexpr auto c_warp_y_lengths =
        to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
    // Zero offsets for the warp-internal Y dimensions (we take them whole).
    constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

    lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
        merge_sequences(
            sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
            c_warp_y_index_zeros),
        merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
                        c_warp_y_lengths));
}
|
||||
|
||||
template <typename LdsTile, typename InLdsWindow>
|
||||
CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
|
||||
{
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
}
|
||||
|
||||
/// @brief Load all D (auxiliary input) tiles and apply the stored element-wise functor
///        to combine them with the output tile in place.
///
/// The argument tuple passed to the functor is (c_out_tensor /*dst*/, c_out_tensor
/// /*src*/, d0, d1, ...); the first entry is written, the rest are read-only.
/// No-op effect when NumDTensor == 0 beyond calling elfunc_ with (dst, src).
template <typename DramWindows, typename COutTensor>
CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
{
    // Load every D tile for the current window position.
    const auto ds_tensor = generate_tuple(
        [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});

    // Build a reference tuple: output appears twice (as destination and as input),
    // followed by const references to the loaded D tiles.
    const auto c_ds_tiles = concat_tuple_of_reference(
        tie(c_out_tensor, c_out_tensor),
        generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
                     number<NumDTensor>{}));

    // Unpack the tuple into the element-wise functor's argument list.
    tile_elementwise_inout_unpack(elfunc_, c_ds_tiles);
}
|
||||
|
||||
template <typename OutDramWindow, typename COutTensor>
|
||||
CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
|
||||
const COutTensor& c_out_tensor)
|
||||
{
|
||||
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
|
||||
memory_operation_enum::set)
|
||||
{
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Move both the output and D tensors windows for the next access.
|
||||
*/
|
||||
template <index_t iAccess, typename OutDramWindow, typename DDramWindows>
|
||||
CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
|
||||
{
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(number<iAccess>{});
|
||||
|
||||
// move the output dram window
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
|
||||
// move windows for each of the D matrices (inputs for element-wise)
|
||||
static_for<0, NumDTensor, 1>{}([&](auto idx) {
|
||||
move_tile_window(d_dram_windows[idx], {step.at(number<0>{}), step.at(number<1>{})});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
// Tag type meaning "no scale supplied"; used as the default for the ScaleM/ScaleN
// template parameters so the scale-free call paths compile to no-ops.
struct EmptyScale
{
};
|
||||
|
||||
// Trait that extracts the element type carried by a scale argument.
// Primary template: used when T has no nested ::DataType (e.g. EmptyScale or a plain
// scalar); falls back to float.
template <typename, typename = void>
struct ScaleDataType
{
    using DataType = float;
};

// Specialization selected via std::void_t detection when T declares a nested
// ::DataType (e.g. a tile-window type); forwards that type.
template <typename T>
struct ScaleDataType<T, std::void_t<typename T::DataType>>
{
    using DataType = typename T::DataType;
};
|
||||
|
||||
/// @brief Epilogue entry point for the TiledMMAPermuteN path (no LDS staging).
///
/// Shuffles accumulator registers directly into a permuted-N output layout, applies
/// optional scales, converts to ODataType and stores/updates global memory, advancing
/// the output (and D-tensor) windows by one MPerXdl*MWave chunk per M repeat.
/// The p_smem argument is intentionally unused on this path.
///
/// Scale handling (resolved at compile time):
///  - both ScaleM/ScaleN == EmptyScale: no scaling;
///  - both == AccDataType: uniform scalar multiply;
///  - otherwise: treated as windows, loaded with the same tile distribution so the
///    per-thread indexing matches the output tile.
template <typename ODramWindow,
          typename OAccTile,
          typename DsDramWindows,
          typename ScaleM = EmptyScale,
          typename ScaleN = EmptyScale,
          int EnablePermuateN_ = TiledMMAPermuteN,
          std::enable_if_t<EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
                               const OAccTile& o_acc_tile,
                               const DsDramWindows& ds_dram_windows,
                               void* /* p_smem */,
                               const ScaleM& scale_m = {},
                               const ScaleN& scale_n = {})
{
    // Number of C-matrix rows each lane holds in its register buffer.
    static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();

    static_assert(MPerXdl % RowsPerLane == 0,
                  "CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");
    // M is decomposed as (MWave, MPerXdl/RowsPerLane, RowsPerLane),
    // N as (NWave, NPerXdl, NRepeat).
    constexpr int kM0 = MWave;
    constexpr int kM2 = RowsPerLane;
    constexpr int kM1 = MPerXdl / kM2;

    constexpr int kN0 = NWave;
    constexpr int kN1 = NPerXdl;
    constexpr int kN2 = NRepeat;

    using IntrThreadShuffleEncode =
        tile_distribution_encoding<sequence<>,
                                   tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
                                   tuple<sequence<1, 2>, sequence<1, 2>>,
                                   tuple<sequence<0, 0>, sequence<1, 1>>,
                                   sequence<1, 2>,
                                   sequence<2, 2>>;
    constexpr auto dram_tile_distribution =
        make_static_tile_distribution(IntrThreadShuffleEncode{});

    // D-tensor windows share the output distribution so indices line up per thread.
    auto d_dram_windows = generate_tuple(
        [&](auto idx) {
            return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
        },
        number<NumDTensor>{});

    constexpr auto c_warp_y_lengths =
        to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
    constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

    // Register-resident staging tensors: raw accumulator slice and converted output.
    auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
    auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);

    // Optional scales (must share the same distribution to match per-thread indexing)
    constexpr bool has_scales =
        !std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
    constexpr bool has_scalar_scales =
        std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;

    // Tiles to hold row/col scales when present (type falls back to float when the
    // scale argument carries no nested DataType).
    using SMType = typename ScaleDataType<ScaleM>::DataType;
    using SNType = typename ScaleDataType<ScaleN>::DataType;

    auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
    auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);

    // Build windows only if non-scalar scales are provided
    auto scale_m_window = [&]() {
        if constexpr(has_scales && !has_scalar_scales)
        {
            return make_tile_window(scale_m, dram_tile_distribution);
        }
        else
        {
            return EmptyScale{};
        }
    }();
    auto scale_n_window = [&]() {
        if constexpr(has_scales && !has_scalar_scales)
        {
            return make_tile_window(scale_n, dram_tile_distribution);
        }
        else
        {
            return EmptyScale{};
        }
    }();

    static_for<0, MRepeat, 1>{}([&](auto mIter) {
        // Slice accumulators for this M repeat into the permuted layout
        shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
            merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
            merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));

        // If non-scalar scales provided, load them with identical distribution
        if constexpr(has_scales && !has_scalar_scales)
        {
            sm_tile = load_tile(scale_m_window); // row scales in permuted layout
            sn_tile = load_tile(scale_n_window); // col scales in permuted layout
        }

        // Repack per-lane rows into the permuted-N output ordering, fusing scaling
        // and the ODataType conversion in the same pass.
        static_for<0, NRepeat, 1>{}([&](auto n_idx) {
            // Number of accumulator elements per N-plane in the source buffer.
            const index_t plane = c_warp_y_lengths.product();

            static_for<0, kM2, 1>{}([&](auto m_lane) {
                const int src = n_idx * plane + m_lane;  // source row in this N-plane
                const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
                AccDataType v = shuffle_acc.get_thread_buffer()[src];

                if constexpr(has_scalar_scales)
                {
                    v = static_cast<AccDataType>(v * scale_m * scale_n);
                }
                else if constexpr(has_scales && !has_scalar_scales)
                {
                    const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
                    const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
                    v = static_cast<AccDataType>(v * sm * sn);
                }

                c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
            });
        });

        // store/update depending on the destination view's configured memory op
        if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
                     memory_operation_enum::set)
        {
            store_tile(out_dram_window, c_out_tensor);
        }
        else
        {
            update_tile(out_dram_window, c_out_tensor);
        }

        // advance output (and any D-tensors) by one MPerXdl*MWave chunk
        move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
        static_for<0, NumDTensor, 1>{}([&](auto idx) {
            move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
        });
    });
}
|
||||
|
||||
/// @brief Epilogue entry point for the default (LDS-staged) CShuffle path.
///
/// Per space-filling-curve access: slice the accumulator, optionally scale it, cast to
/// ODataType through an LDS round-trip (write with the shuffle distribution, read back
/// with the DRAM-friendly distribution), apply the D-tensor element-wise op, then
/// store/update global memory and advance all windows.
///
/// @param p_smem LDS scratch; must be at least GetSmemSize() bytes.
/// NOTE(review): the two block_sync_lds() calls bracket the write/read of the shared
/// tile; their placement is load-bearing for correctness across iterations.
template <typename ODramWindow,
          typename OAccTile,
          typename DsDramWindows,
          typename ScaleM = EmptyScale,
          typename ScaleN = EmptyScale,
          int EnablePermuateN_ = TiledMMAPermuteN,
          std::enable_if_t<!EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
                               const OAccTile& o_acc_tile,
                               const DsDramWindows& ds_dram_windows,
                               void* p_smem,
                               const ScaleM& scale_m = {},
                               const ScaleN& scale_n = {})
{
    // Distribution used when writing the accumulator slice into LDS.
    constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());

    auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);

    // View the shared-memory scratch through the (possibly padded/permuted) descriptor.
    constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
    auto o_lds_block = make_tensor_view<address_space_enum::lds>(
        static_cast<ODataType*>(p_smem), lds_block_desc);

    // Window used to WRITE into LDS (shuffle distribution)...
    auto in_lds_window = make_tile_window(
        o_lds_block,
        make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
        {0, 0},
        LdsTileDistr);

    // ...and to READ back out (distribution attached later at the load site).
    auto out_lds_window = make_tile_window(
        o_lds_block,
        make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
        {0, 0});

    constexpr index_t num_access = SFC::get_num_of_access();

    static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
                  "Currently, the CShuffle Epilogue only supports the Row Major Output layout");

    // Thread-raked pattern for coalesced global-memory stores.
    using TileEncodingPattern =
        tile_distribution_encoding_pattern_2d<kBlockSize,
                                              MPerIterationShuffle,
                                              NPerIterationShuffle,
                                              GetVectorSizeC(),
                                              tile_distribution_pattern::thread_raked,
                                              Problem::kNumWaveGroups>;
    constexpr auto dram_tile_distribution =
        TileEncodingPattern::make_2d_static_tile_distribution();

    // D-tensor windows share the DRAM distribution so per-thread indices line up.
    auto d_dram_windows = generate_tuple(
        [&](auto idx) {
            return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
        },
        number<NumDTensor>{});

    constexpr bool has_scales =
        !std::is_same_v<ScaleM, EmptyScale> && !std::is_same_v<ScaleN, EmptyScale>;
    constexpr bool has_scalar_scales =
        std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
    // Scalars pass through unchanged; window scales adopt the LDS tile's distribution.
    auto scale_m_window = [&]() {
        if constexpr(has_scalar_scales)
        {
            return scale_m;
        }
        else if constexpr(has_scales)
        {
            return make_tile_window(scale_m, lds_tile.get_tile_distribution());
        }
        else
        {
            return EmptyScale{};
        }
    }();
    auto scale_n_window = [&]() {
        if constexpr(has_scalar_scales)
        {
            return scale_n;
        }
        else if constexpr(has_scales)
        {
            return make_tile_window(scale_n, lds_tile.get_tile_distribution());
        }
        else
        {
            return EmptyScale{};
        }
    }();

    static_for<0, num_access, 1>{}([&](auto iAccess) {
        // Wait until every wave is done reading LDS from the previous iteration.
        block_sync_lds();
        slice_acc_tile<iAccess>(o_acc_tile, lds_tile);

        if constexpr(has_scales)
        {
            scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
        }

        // Cast to ODataType and write to LDS; sync before re-reading with the
        // DRAM-friendly distribution.
        cast_lds_tile(lds_tile, in_lds_window);
        block_sync_lds();

        auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));

        apply_d_tensors(d_dram_windows, c_out_tensor);
        store_to_dram(out_dram_window, c_out_tensor);
        move_windows<iAccess>(out_dram_window, d_dram_windows);
    });
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,91 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "default_2d_epilogue.hpp"
|
||||
#include "dynamic_quant_epilogue.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Traits for the combined (direct-store + dynamic-quant) epilogue. This is a plain
// alias of DynamicQuantEpilogueTraits, so existing trait definitions can be reused
// with the combined epilogue unchanged.
template <bool kPadM_,
          bool kPadN_,
          bool UseSmoothInputScale_,
          bool UseRawStore_ = true,
          bool UseMax3_ = false>
using Default2DAndDynamicQuantEpilogueTraits =
    DynamicQuantEpilogueTraits<kPadM_, kPadN_, UseSmoothInputScale_, UseRawStore_, UseMax3_>;
|
||||
|
||||
// Problem descriptor for Default2DAndDynamicQuantEpilogue: bundles the data types,
// block shape and traits consumed by both the direct-store and the dynamic-quant
// sub-epilogues. This epilogue stores out an M*N matrix, row major.
template <typename AccDataType_,
          typename SmoothScaleDataType_,
          typename YScaleDataType_,
          typename ODataType_,
          typename UnquantYDataType_,
          typename BlockShape_,
          typename Traits_>
struct Default2DAndDynamicQuantEpilogueProblem
{
    using AccDataType         = remove_cvref_t<AccDataType_>;         // accumulator element type
    using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>; // smooth-quant input scale type
    using YScaleDataType      = remove_cvref_t<YScaleDataType_>;      // per-row output scale type
    using ODataType           = remove_cvref_t<ODataType_>;           // quantized output element type
    using UnquantYDataType    = remove_cvref_t<UnquantYDataType_>;    // unquantized (direct) output type
    using BlockShape          = remove_cvref_t<BlockShape_>; // can consume generic 2d shape
    using Traits              = remove_cvref_t<Traits_>;
};
|
||||
|
||||
/// Composite epilogue that emits TWO outputs from one accumulator tile:
///  1. the unquantized result, via Default2DEpilogue (plain M*N row-major store), and
///  2. the dynamically-quantized result plus its per-row Y scales, via
///     DynamicQuantEpilogue.
/// Both sub-epilogues share the same SMEM scratch (sized to the larger requirement).
template <typename Problem_, typename Policy_ = void>
struct Default2DAndDynamicQuantEpilogue
{
    using Problem          = remove_cvref_t<Problem_>;
    using AccDataType      = remove_cvref_t<typename Problem::AccDataType>;
    using UnquantYDataType = remove_cvref_t<typename Problem::UnquantYDataType>;

    static constexpr bool kPadM       = Problem::Traits::kPadM;
    static constexpr bool kPadN       = Problem::Traits::kPadN;
    static constexpr bool UseRawStore = Problem::Traits::UseRawStore;

    // Direct (unquantized) store path reuses the generic 2D epilogue with a
    // problem synthesized from this problem's types/traits.
    using Default2DProblem =
        Default2DEpilogueProblem<AccDataType, UnquantYDataType, kPadM, kPadN, UseRawStore>;
    using Default2D    = Default2DEpilogue<Default2DProblem>;
    using DynamicQuant = DynamicQuantEpilogue<Problem>;

    // Scratch is shared sequentially by the two sub-epilogues, so the larger
    // requirement dominates.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(Default2D::GetSmemSize(), DynamicQuant::GetSmemSize());
    }

    /// Overload with a smooth-quant input-scale window.
    /// @param o_direct_dram_window_tmp destination for the unquantized output
    /// @param o_quant_dram_window_tmp  destination for the quantized output
    /// @param sm_scale_window_         smooth-quant input scales
    /// @param y_scale_window           written with the computed per-row quant scales
    /// @param o_acc_tile               block accumulator tile (read-only)
    /// @param smem                     shared scratch of at least GetSmemSize() bytes
    template <typename ODramWindowTmpD,
              typename ODramWindowTmpQ,
              typename SmoothScaleWindow,
              typename YScaleWindow,
              typename OAccTile>
    CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp,
                                   ODramWindowTmpQ& o_quant_dram_window_tmp,
                                   const SmoothScaleWindow& sm_scale_window_,
                                   YScaleWindow& y_scale_window,
                                   const OAccTile& o_acc_tile,
                                   void* smem)
    {
        // Direct store first, then quantized store; both consume the same accumulator.
        Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem);
        DynamicQuant{}(o_quant_dram_window_tmp, sm_scale_window_, y_scale_window, o_acc_tile, smem);
    }

    /// Overload without smooth-quant input scales.
    template <typename ODramWindowTmpD,
              typename ODramWindowTmpQ,
              typename YScaleWindow,
              typename OAccTile>
    CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp,
                                   ODramWindowTmpQ& o_quant_dram_window_tmp,
                                   YScaleWindow& y_scale_window,
                                   const OAccTile& o_acc_tile,
                                   void* smem)
    {
        Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem);
        DynamicQuant{}(o_quant_dram_window_tmp, y_scale_window, o_acc_tile, smem);
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
291
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
Normal file
291
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
Normal file
@@ -0,0 +1,291 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this epilogue just store out a M*N matrix, row major
|
||||
|
||||
// Problem descriptor for Default2DEpilogue: accumulator/output element types plus
// padding and raw-store flags. NumDTensor is fixed at 0 here; derived problems
// (e.g. DefaultGemm2DEpilogueProblem) override it for D-tensor fusion.
template <typename AccDataType_,
          typename ODataType_,
          bool kPadM_,
          bool kPadN_,
          bool UseRawStore_ = true>
struct Default2DEpilogueProblem
{
    using AccDataType = remove_cvref_t<AccDataType_>; // accumulator element type
    using ODataType   = remove_cvref_t<ODataType_>;   // stored output element type
    static constexpr bool kPadM        = kPadM_;      // M dimension may be partially out of bounds
    static constexpr bool kPadN        = kPadN_;      // N dimension may be partially out of bounds
    static constexpr bool UseRawStore  = UseRawStore_; // use raw (fenced) store path when padding
    static constexpr index_t NumDTensor = 0;
};
|
||||
|
||||
// GEMM-flavored extension of Default2DEpilogueProblem: adds A/B/D tensor types,
// layouts, the C/D element-wise functor, block and XDL tile sizes, and the
// C-transposed flag needed to derive vector sizes.
template <typename AsDataType_,
          typename BsDataType_,
          typename DsDataType_,
          typename AccDataType_,
          typename ODataType_,
          typename DsLayout_,
          typename CLayout_,
          typename CDElementwise_,
          index_t kM_,
          index_t kN_,
          bool kPadM_,
          bool kPadN_,
          index_t kMPerXdl_,
          index_t kNPerXdl_,
          index_t kKPerXdl_,
          bool isCTransposed_,
          bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
    : public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
{
    using AsDataType    = remove_cvref_t<AsDataType_>;    // A operand type(s), possibly a tuple
    using BsDataType    = remove_cvref_t<BsDataType_>;    // B operand type(s), possibly a tuple
    using CLayout       = remove_cvref_t<CLayout_>;       // output (C/E) memory layout
    using DsDataType    = remove_cvref_t<DsDataType_>;    // tuple of D tensor element types
    using CDElementwise = remove_cvref_t<CDElementwise_>; // functor fusing C with the Ds
    using DsLayout      = remove_cvref_t<DsLayout_>;      // layouts of the D tensors
    static constexpr index_t kMPerBlock = kM_;
    static constexpr index_t kNPerBlock = kN_;
    static constexpr index_t kMPerXdl   = kMPerXdl_;
    static constexpr index_t kNPerXdl   = kNPerXdl_;
    static constexpr index_t kKPerXdl   = kKPerXdl_;
    // NOTE(review): a bool template parameter stored as index_t — works via implicit
    // conversion, but `bool` would express the intent better; confirm no decltype users.
    static constexpr index_t isCTransposed = isCTransposed_;

    // Overrides the base's NumDTensor = 0 with the actual D-tensor count.
    static constexpr index_t NumDTensor = DsDataType::size();

    static_assert(NumDTensor == DsLayout::size(),
                  "The size of DsDataType and DsLayout should be the same");
};
|
||||
|
||||
/// Minimal epilogue: casts the accumulator tile to ODataType and stores (or updates)
/// it to a row-major M*N window, optionally fusing D tensors through the problem's
/// CDElementwise functor. Needs no shared memory.
///
/// The third operator() argument is overloaded by TYPE: it is either the tuple of
/// D-tensor windows, or (when convertible to a partition index of the accumulator's
/// distribution) a partition index forwarded to the store — see is_partition_index.
template <typename Problem_, typename Policy_ = void>
struct Default2DEpilogue
{
    using Problem     = remove_cvref_t<Problem_>;
    using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
    using ODataType   = remove_cvref_t<typename Problem::ODataType>;
    static constexpr bool kPadM       = Problem::kPadM;
    static constexpr bool kPadN       = Problem::kPadN;
    static constexpr bool UseRawStore = Problem::UseRawStore;

    // This epilogue operates entirely out of registers.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }

    // TODO: this function assume store out vector size is the same as OAccTile last dimension size
    // how do we fix this ?
    template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
    CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
                                   const OAccTile& o_acc_tile,
                                   const DsDramWindows& ds_dram_windows,
                                   void* = nullptr) const
    {
        // True when the 3rd argument is actually a partition index rather than a
        // tuple of D windows (duck-typed dispatch on convertibility).
        constexpr bool is_partition_index =
            std::is_convertible_v<decltype(ds_dram_windows),
                                  decltype(get_partition_index(
                                      o_acc_tile.get_tile_distribution()))>;

        // Shared store/update logic, parameterized on the tile to write.
        const auto storeOrUpdateTile = [&](const auto& o_tile) {
            // TODO: this is ugly
            if constexpr(UseRawStore && (kPadM || kPadN))
            {
                // Padded + raw path: raw stores with an explicit store fence.
                // FIXME?
                // if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
                // memory_operation_enum::set)
                // NOTE(review): the memory-op check is disabled; the update_tile_raw
                // branch below is currently unreachable.
                if constexpr(true)
                {
                    if constexpr(is_partition_index)
                    {
                        store_tile_raw(o_dram_window_tmp,
                                       cast_tile<ODataType>(o_tile),
                                       /*partition_index=*/ds_dram_windows);
                    }
                    else
                    {
                        store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
                    }
                }
                else
                {
                    update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
                }
                buffer_store_fence();
            }
            else
            {
                // FIXME?
                // if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
                // memory_operation_enum::set)
                // NOTE(review): same disabled check here; the update_tile branches
                // below are currently unreachable.
                if constexpr(true)
                {
                    if constexpr(is_partition_index)
                    {
                        store_tile(o_dram_window_tmp,
                                   cast_tile<ODataType>(o_tile),
                                   /*partition_index=*/ds_dram_windows);
                    }
                    else
                    {
                        store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
                    }
                }
                else
                {
                    if constexpr(is_partition_index)
                    {
                        update_tile(o_dram_window_tmp,
                                    cast_tile<ODataType>(o_tile),
                                    /*partition_index=*/ds_dram_windows);
                    }
                    else
                    {
                        update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
                    }
                }
            }
        };

        // D-tensor fusion path: only taken when the 3rd argument really is a window
        // tuple and the problem declares at least one D tensor.
        if constexpr(!std::is_same_v<DsDramWindows, std::nullptr_t> && !is_partition_index &&
                     Problem::NumDTensor >= 1)
        {
            // Result type of loading a D tile through the accumulator's distribution.
            using elementwise_result_t = decltype(load_tile(
                make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(),
                                 make_tuple(Problem::kMPerBlock, Problem::kNPerBlock),
                                 ds_dram_windows[number<0>{}].get_window_origin(),
                                 o_acc_tile.get_tile_distribution())));

            elementwise_result_t elementwise_result;

            // Load every D tile with the accumulator's distribution so per-thread
            // element indices match.
            const auto d_tensor_tuple = generate_tuple(
                [&](auto idx) {
                    const auto d_tile_window =
                        make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution());
                    return load_tile(d_tile_window);
                },
                number<Problem::NumDTensor>{});

            // Argument tuple: (result /*dst*/, acc /*src*/, d0, d1, ...).
            const auto c_d_tuple = concat_tuple_of_reference(
                tie(elementwise_result, o_acc_tile),
                generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; },
                             number<Problem::NumDTensor>{}));

            tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple);

            storeOrUpdateTile(elementwise_result);
        }
        else
        {
            // No D fusion: store the (cast) accumulator directly.
            storeOrUpdateTile(o_acc_tile);
        }
    }
};
|
||||
|
||||
/// GEMM-aware variant of Default2DEpilogue: derives the warp-gemm configuration from
/// the problem's operand types and XDL tile sizes, and exposes the store vector sizes
/// (GetVectorSizeC / GetVectorSizeD) implied by the resulting warp C distribution.
template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
    using Problem     = remove_cvref_t<Problem_>;
    using AsDataType  = remove_cvref_t<typename Problem::AsDataType>;
    using BsDataType  = remove_cvref_t<typename Problem::BsDataType>;
    using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
    using ODataType   = remove_cvref_t<typename Problem::ODataType>;
    // A/B may be supplied as a single type or as a tuple of types (multi-operand).
    static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
    static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;

    // Normalize both to tuple form so element access below is uniform.
    using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
                                               remove_cvref_t<AsDataType>,
                                               remove_cvref_t<tuple<AsDataType>>>;

    using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
                                               remove_cvref_t<BsDataType>,
                                               remove_cvref_t<tuple<BsDataType>>>;

    // First operand type of each side drives the warp-gemm dispatch.
    using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
    using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
    // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
    using BTypeToUse =
        std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;

    using DsDataType    = remove_cvref_t<typename Problem::DsDataType>;
    using DsLayout      = remove_cvref_t<typename Problem::DsLayout>;
    using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
    using CLayout       = remove_cvref_t<typename Problem::CLayout>;
    static constexpr index_t kMPerXdl      = Problem::kMPerXdl;
    static constexpr index_t kNPerXdl      = Problem::kNPerXdl;
    static constexpr index_t kKPerXdl      = Problem::kKPerXdl;
    static constexpr index_t isCTransposed = Problem::isCTransposed;

    // Warp-level GEMM chosen from operand types and XDL tile shape.
    using WG = WarpGemmDispatcher<ADataType,
                                  BTypeToUse,
                                  AccDataType,
                                  kMPerXdl,
                                  kNPerXdl,
                                  kKPerXdl,
                                  isCTransposed>;

    using CWarpDstr = typename WG::CWarpDstr;

    /// Number of contiguous output elements a thread can store in one vectorized
    /// access, derived from the warp C distribution and the output layout.
    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
    {
        // N is contiguous dimension
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            if constexpr(isCTransposed)
            {
                // In this case each thread has multiple consecutive elements in
                // N dimension, however consecutive threads' elements have stride.
                constexpr index_t NDimY = CWarpDstr::NDimY;
                constexpr auto c_warp_y_lengths =
                    CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
                static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
                              c_warp_y_lengths.get(number<NDimY - 1>{}));
                return c_warp_y_lengths.get(number<NDimY - 1>{});
            }
            else
            {
                // In this case each thread has just a single item in Ndim
                return (WG::WarpGemmAttribute::Impl::kCNLane *
                        WG::WarpGemmAttribute::Impl::kBNBlock) /
                       WG::kN;
            }
        }
        // M is contiguous dimension
        else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
        {
            if constexpr(isCTransposed)
            {
                // In this case each thread has just a single item in Mdim
                // NOTE(review): divides by WG::kN even though this is the M-contiguous
                // case — presumably intentional for the transposed impl, but verify.
                return (WG::WarpGemmAttribute::Impl::kCNLane *
                        WG::WarpGemmAttribute::Impl::kAMBlock) /
                       WG::kN;
            }
            else
            {
                // In this case each thread has multiple consecutive elements in
                // M dimension, however consecutive threads' elements have stride.
                constexpr index_t NDimY = CWarpDstr::NDimY;
                constexpr auto c_warp_y_lengths =
                    CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
                static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
                              c_warp_y_lengths.get(number<NDimY - 1>{}));
                return c_warp_y_lengths.get(number<NDimY - 1>{});
            }
        }
        else
        {
            static_assert(false, "Unsupported CLayout!");
        }
    }

    /// Vector size for D tensor I; currently identical to C for every index.
    template <index_t I>
    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number<I> index)
    {
        return GetVectorSizeC();
    }
};
|
||||
|
||||
} // namespace ck_tile
|
||||
New file added by this commit: include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp (212 lines).
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Compile-time configuration knobs for the dynamic-quantization epilogue.
///
/// @tparam kPadM_               Pad the M (row) dimension of the output tile.
/// @tparam kPadN_               Pad the N (column) dimension of the output tile.
/// @tparam UseSmoothInputScale_ Apply a per-column smooth input scale before
///                              quantization (selects the smooth-quant operator()).
/// @tparam UseRawStore_         Use raw stores plus an explicit store fence when
///                              either dimension is padded (defaults to true).
/// @tparam UseMax3_             Use the fused v_max3_f32 abs-max reduction path
///                              (only taken for float accumulators with an even
///                              per-row element count; defaults to false).
template <bool kPadM_,
          bool kPadN_,
          bool UseSmoothInputScale_,
          bool UseRawStore_ = true,
          bool UseMax3_     = false>
struct DynamicQuantEpilogueTraits
{
    static constexpr bool kPadM              = kPadM_;
    static constexpr bool kPadN              = kPadN_;
    static constexpr bool UseSmoothInputScale = UseSmoothInputScale_;
    static constexpr bool UseRawStore        = UseRawStore_;
    static constexpr bool UseMax3            = UseMax3_;
};
|
||||
|
||||
// this epilogue just store out a M*N matrix, row major
|
||||
/// @brief Problem descriptor bundling the types and shape consumed by
///        DynamicQuantEpilogue.
///
/// @tparam AccDataType_         Accumulator element type (reductions and scaling
///                              are performed in this type).
/// @tparam SmoothScaleDataType_ Element type of the per-column smooth input scale.
/// @tparam YScaleDataType_      Element type used to store the per-row y-scale.
/// @tparam ODataType_           Quantized output element type.
/// @tparam BlockShape_          2D block tile shape descriptor.
/// @tparam Traits_              A DynamicQuantEpilogueTraits instantiation.
template <typename AccDataType_,
          typename SmoothScaleDataType_,
          typename YScaleDataType_,
          typename ODataType_,
          typename BlockShape_,
          typename Traits_>
struct DynamicQuantEpilogueProblem
{
    using AccDataType         = remove_cvref_t<AccDataType_>;
    using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
    using YScaleDataType      = remove_cvref_t<YScaleDataType_>;
    using ODataType           = remove_cvref_t<ODataType_>;
    using BlockShape          = remove_cvref_t<BlockShape_>; // can consume generic 2d shape
    using Traits              = remove_cvref_t<Traits_>;
};
|
||||
|
||||
// TODO: we should put descriptor creation function into policy
|
||||
/// @brief Epilogue that dynamically quantizes an accumulator tile, row by row.
///
/// For each row of the M*N accumulator tile it computes the absolute maximum,
/// derives a per-row scale (absmax / numeric<ODataType>::max()), stores that
/// scale through a separate window, divides the row by the scale, and stores
/// the quantized result (row-major output). An optional smooth-input scale can
/// be applied per column beforehand via the three-window operator().
template <typename Problem_, typename Policy_ = void>
struct DynamicQuantEpilogue
{
    using Problem             = remove_cvref_t<typename Problem_>;
    using AccDataType         = remove_cvref_t<typename Problem::AccDataType>;
    using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
    using YScaleDataType      = remove_cvref_t<typename Problem::YScaleDataType>;
    using ODataType           = remove_cvref_t<typename Problem::ODataType>;
    using BlockShape          = remove_cvref_t<typename Problem::BlockShape>;
    static constexpr bool kPadM       = Problem::Traits::kPadM;
    static constexpr bool kPadN       = Problem::Traits::kPadN;
    static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
    static constexpr bool UseMax3     = Problem::Traits::UseMax3;

    /// @brief In-thread 2D block reduction used for the per-row abs-max.
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
    {
        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
        return BlockReduce2d<P_>{};
    }

    /// @brief Intra-warp combine step matching GetBlockReduce2d().
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
    {
        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
        return BlockReduce2dSync<P_>{};
    }

    /// @brief Cross-warp combine step (goes through shared-memory scratch).
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
    {
        using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
        return BlockReduce2dCrossWarpSync<P_>{};
    }

    /// @brief Tile distribution used to load the per-column smooth input scale.
    CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution()
    {
        using S = BlockShape;
#if 0
        // don't remove this
        // Note that if we set encoding purposely like this, you will result in compile fail
        // TODO: sm_scale create local-scratch to accept arbitrary acc input (with same length)
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
                tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
                tuple<sequence<0, 1>, sequence<0, 1>>,
                tuple<sequence<1, 1>, sequence<2, 2>>,
                sequence<0, 1, 1>,
                sequence<0, 0, 3>>{});
#else
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
                tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
                tuple<sequence<0, 1>, sequence<0, 1>>,
                tuple<sequence<0, 1>, sequence<1, 2>>,
                sequence<1, 1>,
                sequence<0, 3>>{});
#endif
    }

    /// @brief Shared-memory scratch size needed by this epilogue -- delegated
    ///        to the cross-warp reduction, the only stage that uses smem here.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
        return reduce_crosswarp_sync.GetSmemSize();
    }

    /// @brief Core dynamic-quantization path shared by both operator() overloads.
    ///
    /// @param o_dram_window_tmp  window receiving the quantized output tile
    /// @param y_scale_window     window receiving the per-row scales
    /// @param o_acc_tile         accumulator tile to quantize (copied locally)
    /// @param smem               shared-memory scratch for the cross-warp reduce
    template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
    CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_dram_window_tmp,
                             YScaleWindow& y_scale_window,
                             const OAccTile& o_acc_tile,
                             void* smem)
    {
        auto reduce                = GetBlockReduce2d();
        auto reduce_sync           = GetBlockReduce2dSync();
        auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();

        // Working copy of the accumulators; scaled in place before the store.
        auto o_acc_tmp = o_acc_tile;

        const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };

        auto row_absmax = [&]() {
            constexpr auto y_size_per_row =
                OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
                    number<1>{});
            if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
            {
                // fast max3+abs implementation: v_max3_f32 folds two abs()
                // operands and the running max into a single instruction
                // (AMD GCN/CDNA ISA), consuming two values per step.
                const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
                    float rtn;
                    asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
                                 : "=v"(rtn)
                                 : "v"(acc_), "v"(v_0_), "v"(v_1_));
                    return rtn;
                };
                return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
            }
            else
            {
                return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_absmax);
            }
        }();
        // Combine partial maxima within a warp, then across warps via smem.
        reduce_sync(row_absmax, f_absmax);
        reduce_crosswarp_sync(row_absmax, smem, f_absmax);

        // here y_scale is Acc type, need convert to YScale type later
        auto y_scale = tile_elementwise_in(
            [&](const auto& v_) {
                return v_ / type_convert<AccDataType>(numeric<ODataType>::max());
            },
            row_absmax);

        store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));

        // Quantize: divide every element by its row's scale.
        // NOTE(review): an all-zero row gives y_scale == 0 and a 0/0 division
        // here -- confirm upstream guarantees a non-zero abs-max.
        sweep_tile(o_acc_tmp, [&](auto idx) {
            constexpr auto row_id = make_tuple(idx[number<0>{}]);
            o_acc_tmp(idx) = o_acc_tmp[idx] / y_scale(row_id);
        });

        // TODO: this is ugly
        if constexpr(UseRawStore && (kPadM || kPadN))
        {
            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
            buffer_store_fence();
        }
        else
        {
            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
        }
    }

    // TODO: this function assume store out vector size is the same as OAccTile last dimension size
    // how do we fix this ?

    // Smooth Dynamic Quant
    /// @brief operator() variant that first multiplies each column of the
    ///        accumulator tile by a smooth input scale, then quantizes via Impl().
    template <typename ODramWindowTmp,
              typename SmoothScaleWindow,
              typename YScaleWindow,
              typename OAccTile>
    CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
                                   const SmoothScaleWindow& sm_scale_window_,
                                   YScaleWindow& y_scale_window,
                                   const OAccTile& o_acc_tile,
                                   void* smem)
    {
        // Re-wrap the incoming window with the distribution expected here.
        const auto sm_scale_window =
            make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());

        auto sm_scale = load_tile(sm_scale_window);

        auto o_acc_tmp = o_acc_tile;

        // Apply the per-column (N-dim, index 1) smooth scale element-wise.
        sweep_tile(o_acc_tmp, [&](auto idx) {
            constexpr auto j_idx = make_tuple(idx[number<1>{}]);
            const auto xs_       = type_convert<AccDataType>(sm_scale[j_idx]);
            o_acc_tmp(idx)       = o_acc_tmp(idx) * xs_;
        });

        Impl(o_dram_window_tmp, y_scale_window, o_acc_tmp, smem);
    }

    // Dynamic Quant
    /// @brief operator() variant without smooth input scaling; forwards to Impl().
    template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
    CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
                                   YScaleWindow& y_scale_window,
                                   const OAccTile& o_acc_tile,
                                   void* smem)
    {
        Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem);
    }
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user