This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,61 @@
# CK Tile Epilogue Chainer
## Overview
The Epilogue Chainer provides a modular epilogue processing framework through scheduler-defined operation graphs.
## Architecture
### Core Design Principle
The chainer follows a **Scheduler-Graph-Node** architecture with shared context:
- **Scheduler**: Defines operation graphs and creates a shared context
- **Graph**: Composes multiple operations into sequential processing units
- **Node**: Wraps individual epilogue operations with their arguments
### EpilogueChainer
The `EpilogueChainer` struct serves as the modular epilogue processing facilitator. It delegates to schedulers for context creation and schedule generation, then processes the resulting operation graphs.
### EpilogueNode
Individual epilogue operations are wrapped in `EpilogueNode` structures that capture required arguments at construction time and automatically forward them during processing. Supports both parameterized and parameter-free operations.
### EpilogueGraph
The `EpilogueGraph` composes multiple nodes into sequential processing units that iterate over multiple accesses if needed, running all operations in order for each iteration.
## Files
### Core Infrastructure
- `epilogue_chainer.hpp` - General chainer, node, and graph infrastructure
- `common_epilogue_ops.hpp` - Epilogue operations usable with any epilogue type
### CShuffle Implementation
- `cshuffle_epilogue_chainer_ops.hpp` - CShuffle-specific problem, context, and slice operations
- `cshuffle_epilogue_schedule.hpp` - CShuffle scheduler with pre-built schedules
## Usage
### Common Operations (common_epilogue_ops.hpp)
These operations work with any context that provides the standardized interface:
- `ScaleScalarOp` - Scale working-tile by scalar values
- `CastAndStoreToLdsOp<DstType>` - Cast working-tile and store to LDS
- `LoadFromLdsOp<Pattern>` - Load output tile from LDS with sync
- `ElementwiseOp<Func, NumAux>` - Apply elementwise operation with auxiliary tensors
- `StoreOp<MemOp>` - Store output tile to global memory
- `MoveWindowsOp<SFC, NumAux>` - Advance windows to next position
### CShuffle-Specific Operations (cshuffle_epilogue_chainer_ops.hpp)
These operations are specific to CShuffle epilogue:
- `CShuffleSliceOp` - Slice accumulator tile based on distribution
- `CShuffleScaleWindowOp` - Scale using tensor windows with shuffle distribution
### Context Interface
Operations communicate through a shared context with standardized members:
- `working_tile`: Tile for intermediate computations
- `out_tile`: Output tile
- `aux_windows`: Tuple of auxiliary tensor windows
- `lds_write_window`: Window for writing to LDS
- `lds_read_window`: Window for reading from LDS
### Schedule Tags
- `DefaultScheduleTag` - Standard: Slice → CastStore → Load → ApplyD → Store → Move
- `RowColQuantScheduleTag` - With window scaling
- `TensorQuantScheduleTag` - With scalar scaling

View File

@@ -0,0 +1,208 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
/// @file common_epilogue_ops.hpp
/// @brief Reusable simple epilogue operations which might be used to compose more complex one.
///
///
/// @par Context Interface
/// Operations expect the context to provide:
/// - working_tile: Tile for intermediate computations
/// - out_tile: Output tile for final results
/// - aux_windows: Tuple of auxiliary tensor windows (e.g., D tensors)
/// - lds_write_window: Window for writing to LDS (if using LDS)
/// - lds_read_window: Window for reading from LDS (if using LDS)
namespace ck_tile {
/// @brief Scale working tile by scalar values
///
/// @par Context Requirements
/// working_tile: Tile to scale
struct ScaleScalarOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context,
typename ScaleA,
typename ScaleB>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
[[maybe_unused]] IAccess iAccess,
Context& context,
const ScaleA& scale_a,
const ScaleB& scale_b)
{
tile_elementwise_inout([&](auto& elem) { elem = elem * scale_a * scale_b; },
context.working_tile);
}
};
/// @brief Cast working tile and store to LDS
///
/// @tparam DataType Target data type for casting
///
/// @par Context Requirements
/// working_tile: Tile to cast
/// lds_write_window: Window for writing to LDS
template <typename DataType>
struct CastAndStoreToLdsOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
[[maybe_unused]] IAccess iAccess,
Context& context)
{
const auto casted_tile = cast_tile<DataType>(context.working_tile);
store_tile(context.lds_write_window, casted_tile);
}
};
/// @brief Load output tile from LDS with synchronization
///
/// @tparam TileEncodingPattern Pattern for tile distribution
///
/// @par Context Requirements
/// lds_read_window: Window for reading from LDS
/// out_tile: Destination for loaded tile
template <typename TileEncodingPattern>
struct LoadFromLdsOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
[[maybe_unused]] IAccess iAccess,
Context& context)
{
constexpr auto tile_distribution = TileEncodingPattern::make_2d_static_tile_distribution();
block_sync_lds();
context.out_tile = load_tile(make_tile_window(context.lds_read_window, tile_distribution));
}
};
/// @brief Apply elementwise operation with auxiliary tensors
///
/// @tparam Elementwise Elementwise functor type
/// @tparam NumAux Number of auxiliary tensors to load and apply
///
/// @par Context Requirements
/// out_tile: In/out tile for elementwise operation
/// aux_windows: Tuple of auxiliary tensor windows
template <typename Elementwise, index_t NumAux>
struct ElementwiseOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
[[maybe_unused]] IAccess iAccess,
Context& context)
{
const auto aux_tiles = generate_tuple(
[&](auto idx) { return load_tile(context.aux_windows[idx]); }, number<NumAux>{});
const auto tiles = concat_tuple_of_reference(
tie(context.out_tile, context.out_tile),
generate_tie([&](auto idx) -> const auto& { return aux_tiles[idx]; },
number<NumAux>{}));
tile_elementwise_inout_unpack(Elementwise{}, tiles);
}
};
/// @brief Store output tile to global memory
///
/// @tparam MemOp Memory operation type (set or atomic_add)
///
/// @par Context Requirements
/// out_tile: Tile to store
template <memory_operation_enum MemOp>
struct StoreOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
[[maybe_unused]] IAccess iAccess,
Context& context)
{
if constexpr(MemOp == memory_operation_enum::set)
{
store_tile(out_window, context.out_tile);
}
else
{
update_tile(out_window, context.out_tile);
}
}
};
/// @brief Move output and auxiliary windows by step from space-filling curve
///
/// @tparam SFC Space filling curve type providing step computation
/// @tparam NumAux Number of auxiliary windows to move
///
/// @par Context Requirements
/// aux_windows: Tuple of windows to move
template <typename SFC, index_t NumAux>
struct MoveWindowsOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
IAccess iAccess,
Context& context)
{
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_window, {step.at(number<0>{}), step.at(number<1>{})});
static_for<0, NumAux, 1>{}([&](auto idx) {
move_tile_window(context.aux_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
});
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,512 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include <optional>
namespace ck_tile {
//------------------------------------------------------------------------------
// CShuffle-specific epilogue operations
// These operations are specific to CShuffle epilogue due to its unique.
//------------------------------------------------------------------------------
/// @brief Slice accumulator tile for CShuffle epilogue
///
/// @par Purpose
/// Extracts a portion of the accumulator tile into the working tile
/// based on the current iteration index. This is CShuffle-specific.
///
/// @tparam SFC Space filling curve type
/// @tparam CWarpDstr Warp distribution type
/// @tparam NumMXdlPerWavePerShuffle XDL tiles in M per wave per shuffle
/// @tparam NumNXdlPerWavePerShuffle XDL tiles in N per wave per shuffle
/// @tparam MPerIterShuffle M elements per shuffle iteration
/// @tparam NPerIterShuffle N elements per shuffle iteration
template <typename SFC,
typename CWarpDstr,
index_t NumMXdlPerWavePerShuffle,
index_t NumNXdlPerWavePerShuffle,
index_t MPerIterShuffle,
index_t NPerIterShuffle>
struct CShuffleSliceOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
IAccess iAccess,
Context& context)
{
constexpr auto idx_start = SFC::get_index(iAccess);
constexpr auto m_iter = number<idx_start.at(number<0>{}) / MPerIterShuffle>{};
constexpr auto n_iter = number<idx_start.at(number<1>{}) / NPerIterShuffle>{};
constexpr auto warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
context.working_tile.get_thread_buffer() = acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<m_iter * NumMXdlPerWavePerShuffle, n_iter * NumNXdlPerWavePerShuffle>{},
warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
warp_y_lengths));
}
};
/// @brief Scale working tile using tensor windows (CShuffle-specific)
///
/// @par Purpose
/// Scales the working tile using row and column scale tensors.
/// CShuffle-specific because it creates scale windows from the
/// working tile's distribution and handles window movement.
///
/// @tparam SFC Space filling curve type
template <typename SFC>
struct CShuffleScaleWindowOp
{
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename IAccess,
typename Context,
typename ScaleRowTensor,
typename ScaleColTensor>
CK_TILE_DEVICE void operator()([[maybe_unused]] OutWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
[[maybe_unused]] const AuxWindows& aux_windows,
[[maybe_unused]] void* p_smem,
IAccess iAccess,
Context& context,
const ScaleRowTensor& scale_row_tensor,
const ScaleColTensor& scale_col_tensor)
{
auto scale_row_window =
make_tile_window(scale_row_tensor, context.working_tile.get_tile_distribution());
auto scale_col_window =
make_tile_window(scale_col_tensor, context.working_tile.get_tile_distribution());
const auto scale_row_tile = load_tile(scale_row_window);
const auto scale_col_tile = load_tile(scale_col_window);
tile_elementwise_inout(element_wise::MultiDMultiply{},
context.working_tile,
context.working_tile,
scale_row_tile,
scale_col_tile);
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(number<iAccess>{});
move_tile_window(scale_row_window, {step.at(number<0>{}), step.at(number<1>{})});
move_tile_window(scale_col_window, {step.at(number<0>{}), step.at(number<1>{})});
}
}
};
//------------------------------------------------------------------------------
// CShuffle problem and base operation definitions
//------------------------------------------------------------------------------
/// @brief Problem configuration for CShuffle epilogue chainer operations
/// @note Mirrors CShuffleEpilogueProblem but uses AsDataType/BsDataType for tuple support
template <typename AsDataType_,
typename BsDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
typename DsLayout_,
typename ELayout_,
typename CDElementwise_,
index_t kM_,
index_t kN_,
index_t MWave_,
index_t NWave_,
index_t MPerXdl_,
index_t NPerXdl_,
index_t KPerXdl_,
bool isCTransposed_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1>
struct CShuffleEpilogueChainProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
};
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogueChainBaseOp
{
using Problem = remove_cvref_t<Problem_>;
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t MWave = Problem::MWave;
static constexpr index_t NWave = Problem::NWave;
static constexpr index_t MPerXdl = Problem::MPerXdl;
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
/**
* @brief Get the vector store size for C tensor.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
{
if constexpr(FixedVectorSize)
{
return VectorSizeC;
}
constexpr index_t max_vector_size = 16;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
/**
* @brief Get the vector store size for Di tensor.
*
* @return The vector store size for Di tensor.
*/
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
{
constexpr index_t max_vector_size = 16;
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else
{
static_assert(false, "Unsupported DLayout!");
}
}
/**
* @brief Shuffle tile configuration parameters
*
* @details These parameters control the number of XDL tiles processed per wave in each shuffle
* iteration:
* - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
*/
static constexpr auto shuffle_tile_tuple = [] {
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
if constexpr(elem_per_thread >= GetVectorSizeC())
{
return std::make_tuple(1, 1);
}
else
{
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by MPerXdl*MWave and "
"num_xdl_shuffles for CShuffleEpilogueStageBase");
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
}
else
{
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by NPerXdl*NWave and "
"num_xdl_shuffles for CShuffleEpilogueStageBase");
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
}
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle =
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
static constexpr auto MNPerIterationShuffle = [] {
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
else
return std::make_tuple(m_val, n_val);
}();
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
using WG = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
AccDataType,
MPerXdl,
NPerXdl,
KPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
// N is contiguous dimension
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
}
// M is contiguous dimension
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<1>{}, number<MPerIterationShuffle>{}));
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding = [] {
if constexpr(BlockedXDLN_PerWarp == 1)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else
{
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
return block_dstr_encoding;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
}
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
/// @brief Context structure for CShuffle epilogue operations
///
/// @par Purpose
/// The context serves as a shared workspace that maintains intermediate results
/// and resources across multiple epilogue operations. It eliminates the need for
/// operations to recreate shared data structures and enables data flow
/// through the operation graph.
///
/// @par Standardized Interface
/// Uses standardized member names so common operations can work with this context:
/// - working_tile: Intermediate tile for shuffle operations
/// - out_tile: Output tile for final results
/// - aux_windows: Auxiliary tensor windows (D tensors)
/// - lds_write_window: Window for writing to LDS
/// - lds_read_window: Window for reading from LDS
template <typename WorkingTileType,
typename LdsBlockType,
typename LdsWriteWindowType,
typename LdsReadWindowType,
typename AuxWindowsType,
typename OutTileType>
struct CShuffleContext
{
WorkingTileType working_tile; // Working tile for shuffle operations
LdsBlockType lds_block; // LDS block view
LdsWriteWindowType lds_write_window; // Window for writing to LDS
LdsReadWindowType lds_read_window; // Window for reading from LDS
AuxWindowsType aux_windows; // Auxiliary tensor windows (D tensors)
OutTileType out_tile; // Output tile
};
template <typename OutDramWindow, typename AccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()([[maybe_unused]] OutDramWindow& out_window,
[[maybe_unused]] const AccTile& acc_tile,
const DsDramWindows& ds_windows,
void* p_smem)
{
static_assert(
std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffleEpilogueStageBase only supports the Row Major Output layout");
constexpr auto working_tile_distr =
make_static_tile_distribution(MakeLdsDistributionEncode());
auto working_tile = make_static_distributed_tensor<AccDataType>(working_tile_distr);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto lds_block = make_tensor_view<address_space_enum::lds>(static_cast<ODataType*>(p_smem),
lds_block_desc);
auto lds_write_window = make_tile_window(
lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
working_tile_distr);
auto lds_read_window = make_tile_window(
lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
constexpr auto dram_tile_distribution =
TileEncodingPattern::make_2d_static_tile_distribution();
auto aux_windows = generate_tuple(
[&](auto idx) { return make_tile_window(ds_windows[idx], dram_tile_distribution); },
number<NumDTensor>{});
auto out_tile = load_tile(make_tile_window(lds_read_window, dram_tile_distribution));
using ContextType = CShuffleContext<decltype(working_tile),
decltype(lds_block),
decltype(lds_write_window),
decltype(lds_read_window),
decltype(aux_windows),
decltype(out_tile)>;
ContextType context;
context.working_tile = working_tile;
context.lds_block = lds_block;
context.lds_write_window = lds_write_window;
context.lds_read_window = lds_read_window;
context.aux_windows = aux_windows;
context.out_tile = out_tile;
return context;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,129 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/epilogue/chainer/epilogue_chainer.hpp"
#include "ck_tile/ops/epilogue/chainer/common_epilogue_ops.hpp"
#include "ck_tile/ops/epilogue/chainer/cshuffle_epilogue_chainer_ops.hpp"
namespace ck_tile {
/// @brief Schedule type tags for epilogue selection
/// @par Purpose
/// Each tag corresponds to a pre-built schedule, these are used to select a schedule
/// Standard epilogue schedule: Slice → CastStore → Load → ApplyD → Store → Move
struct DefaultScheduleTag
{
};
/// RowCol quantization schedule: Slice → ScaleWindow → CastStore → Load → ApplyD → Store → Move
struct RowColQuantScheduleTag
{
};
/// Tensor quantization schedule: Slice → ScaleScalar → CastStore → Load → ApplyD → Store → Move
struct TensorQuantScheduleTag
{
};
/// @brief CShuffle epilogue scheduler providing pre-built schedules
///
/// @par Overview
/// CshuffleEpilogueSchedule acts as the scheduler component for EpilogueChainer.
/// It provides context creation and pre-built schedules. The scheduler
/// uses tags to select/create appropriate epilogue schedule.
///
/// @tparam Problem The epilogue problem configuration
/// @tparam ScheduleTag Tag selecting the epilogue schedule
template <typename Problem, typename ScheduleTag = DefaultScheduleTag>
struct CshuffleEpilogueSchedule
{
using ProblemType = Problem;
using BaseOp = CShuffleEpilogueChainBaseOp<Problem>;
static constexpr index_t NumAccess = BaseOp::SFC::get_num_of_access();
/// @brief Create context for epilogue operations
template <typename OutWindow, typename AccTile, typename AuxWindows>
CK_TILE_DEVICE static auto create_context(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem)
{
return BaseOp{}(out_window, acc_tile, aux_windows, p_smem);
}
/// @brief Make schedule based on compile-time tag selection
template <typename... Args>
CK_TILE_DEVICE static auto make_schedule(Args&&... args)
{
if constexpr(std::is_same_v<ScheduleTag, DefaultScheduleTag>)
{
// Standard epilogue
// Schedule: Slice -> CastAndStoreLds -> Load -> ApplyD -> Store -> MoveWindows
static_assert(sizeof...(args) == 0, "DefaultSchedule expects no arguments");
return make_graph<NumAccess>(
make_node<CShuffleSliceOp<typename BaseOp::SFC,
typename BaseOp::CWarpDstr,
BaseOp::NumMXdlPerWavePerShuffle,
BaseOp::NumNXdlPerWavePerShuffle,
BaseOp::MPerIterationShuffle,
BaseOp::NPerIterationShuffle>>(),
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
make_node<StoreOp<Problem::MemoryOperation>>(),
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
}
else if constexpr(std::is_same_v<ScheduleTag, RowColQuantScheduleTag>)
{
// RowCol quantization schedule with tensor windows
// Schedule: Slice -> ScaleWindow -> CastAndStoreLds -> Load -> ApplyD -> Store ->
// MoveWindows
static_assert(sizeof...(args) == 2,
"RowColQuantSchedule requires exactly 2 scale tensor arguments");
return make_graph<NumAccess>(
make_node<CShuffleSliceOp<typename BaseOp::SFC,
typename BaseOp::CWarpDstr,
BaseOp::NumMXdlPerWavePerShuffle,
BaseOp::NumNXdlPerWavePerShuffle,
BaseOp::MPerIterationShuffle,
BaseOp::NPerIterationShuffle>>(),
make_node<CShuffleScaleWindowOp<typename BaseOp::SFC>>(std::forward<Args>(args)...),
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
make_node<StoreOp<Problem::MemoryOperation>>(),
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
}
else if constexpr(std::is_same_v<ScheduleTag, TensorQuantScheduleTag>)
{
// Tensor quantization schedule with scalar values
// Schedule: Slice -> ScaleScalar -> CastAndStoreLds -> Load -> ApplyD -> Store ->
// MoveWindows
static_assert(sizeof...(args) == 2,
"TensorQuantSchedule requires exactly 2 scalar arguments");
return make_graph<NumAccess>(
make_node<CShuffleSliceOp<typename BaseOp::SFC,
typename BaseOp::CWarpDstr,
BaseOp::NumMXdlPerWavePerShuffle,
BaseOp::NumNXdlPerWavePerShuffle,
BaseOp::MPerIterationShuffle,
BaseOp::NPerIterationShuffle>>(),
make_node<ScaleScalarOp>(std::forward<Args>(args)...),
make_node<CastAndStoreToLdsOp<typename BaseOp::ODataType>>(),
make_node<LoadFromLdsOp<typename BaseOp::TileEncodingPattern>>(),
make_node<ElementwiseOp<typename Problem::CDElementwise, Problem::NumDTensor>>(),
make_node<StoreOp<Problem::MemoryOperation>>(),
make_node<MoveWindowsOp<typename BaseOp::SFC, Problem::NumDTensor>>());
}
else
{
static_assert(false, "Unknown schedule tag");
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,213 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
/// @brief Epilogue Chainer - Modular epilogue processing facilitator
///
/// @par Overview
/// EpilogueChainer provides an interface for processing epilogue operations
/// through schedules. The chainer uses decomposed epilogue operations, these are
/// scheduled/sequenced by a Scheduler to form operation graphs.
///
/// @tparam Scheduler The schedule provider that defines epilogue operation graphs
template <typename Scheduler>
struct EpilogueChainer
{
using Problem = typename Scheduler::ProblemType;
using BaseOp = typename Scheduler::BaseOp;
using ODataType = typename BaseOp::ODataType;
using DsDataType = typename BaseOp::DsDataType;
using DsLayout = typename BaseOp::DsLayout;
using AccDataType = typename BaseOp::AccDataType;
static constexpr auto MemoryOperation = BaseOp::MemoryOperation;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return BaseOp::GetSmemSize(); }
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
{
return BaseOp::GetVectorSizeC();
}
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> idx)
{
return BaseOp::GetVectorSizeD(idx);
}
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
return BaseOp::MakeLdsDistributionEncode();
}
/// @brief Process epilogue through scheduler-defined operation graph
///
/// @par Flow
/// 1. Create shared context through scheduler
/// 2. Generate operation schedule based on arguments
/// 3. Run scheduled operations in sequence
template <typename OutWindow, typename AccTile, typename AuxWindows, typename... Args>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Args&&... args) const
{
// The context serves as a shared workspace that maintains intermediate results
// and resources across multiple epilogue operations.
auto context = Scheduler::create_context(out_window, acc_tile, aux_windows, p_smem);
auto schedule = Scheduler::make_schedule(std::forward<Args>(args)...);
schedule(out_window, acc_tile, aux_windows, p_smem, context);
}
};
/// @brief Epilogue operation wrapper with arguments
///
/// @par Purpose
/// EpilogueNode wraps individual epilogue operations with their required arguments,
/// allowing them to be composed into operation graphs. Arguments are captured at construction
/// time and automatically forwarded during processing.
///
/// @tparam EpilogueType Epilogue operation (e.g., SliceEpilogue, ScaleEpilogue)
/// @tparam Args Types of arguments required by the epilogue operation
template <typename EpilogueType, typename... Args>
struct EpilogueNode
{
using Epilogue = EpilogueType;
ck_tile::tuple<Args...> args;
constexpr EpilogueNode(Args... a) : args(a...) {}
/// @brief Process epilogue without iteration index
template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Context& context) const
{
ck_tile::apply(
[&](auto&&... epilogue_args) {
EpilogueType{}(out_window,
acc_tile,
aux_windows,
p_smem,
context,
std::forward<decltype(epilogue_args)>(epilogue_args)...);
},
args);
}
/// @brief Process epilogue with iteration index
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename Context,
index_t I>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Context& context,
number<I> iAccess) const
{
ck_tile::apply(
[&](auto&&... epilogue_args) {
EpilogueType{}(out_window,
acc_tile,
aux_windows,
p_smem,
iAccess,
context,
std::forward<decltype(epilogue_args)>(epilogue_args)...);
},
args);
}
};
/// @brief Specialization for epilogue operation wrapper with no arguments
template <typename EpilogueType>
struct EpilogueNode<EpilogueType>
{
using Epilogue = EpilogueType;
ck_tile::tuple<> args;
constexpr EpilogueNode() = default;
template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Context& context) const
{
EpilogueType{}(out_window, acc_tile, aux_windows, p_smem, context);
}
template <typename OutWindow,
typename AccTile,
typename AuxWindows,
typename Context,
index_t I>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Context& context,
number<I> iAccess) const
{
EpilogueType{}(out_window, acc_tile, aux_windows, p_smem, iAccess, context);
}
};
/// @brief Operation graph that sequentially composes multiple epilogue operations
///
/// @tparam Steps Number of iterations
/// @tparam EpilogueTypes Types of epilogue nodes in the operation graph
template <index_t Steps, typename... EpilogueTypes>
struct EpilogueGraph
{
ck_tile::tuple<EpilogueTypes...> epilogues;
constexpr EpilogueGraph() = default;
constexpr EpilogueGraph(EpilogueTypes... eps) : epilogues(eps...) {}
/// @brief Process all epilogues for each iteration in sequence
template <typename OutWindow, typename AccTile, typename AuxWindows, typename Context>
CK_TILE_DEVICE void operator()(OutWindow& out_window,
const AccTile& acc_tile,
const AuxWindows& aux_windows,
void* p_smem,
Context& context) const
{
// For each iteration, process all epilogues in order
static_for<0, Steps, 1>{}([&](auto iAccess) {
static_for<0, sizeof...(EpilogueTypes), 1>{}([&](auto I) {
epilogues.template get<I.value>()(
out_window, acc_tile, aux_windows, p_smem, context, iAccess);
});
});
}
};
/// @brief Helper function for creating epilogue nodes
template <typename EpilogueType, typename... Args>
constexpr auto make_node(Args... args)
{
return EpilogueNode<EpilogueType, Args...>{args...};
}
/// @brief Helper function for creating operation graphs
template <index_t Steps, typename... EpilogueTypes>
constexpr auto make_graph(EpilogueTypes... epilogues)
{
return EpilogueGraph<Steps, EpilogueTypes...>{epilogues...};
}
} // namespace ck_tile

View File

@@ -0,0 +1,879 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/host/concat.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include <type_traits>
namespace ck_tile {
template <typename AsDataType_,
typename BsDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
typename DsLayout_,
typename ELayout_,
typename CDElementwise_,
index_t kM_,
index_t kN_,
index_t MWave_,
index_t NWave_,
index_t MPerXdl_,
index_t NPerXdl_,
index_t KPerXdl_,
bool isCTransposed_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
bool DoubleSmemBuffer_ = false>
struct CShuffleEpilogueProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
};
template <typename Problem_, typename Policy_ = void>
struct CShuffleEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
using ATypeToUse = std::conditional_t<std::is_same_v<ADataType, pk_int4_t> ||
std::is_same_v<ADataType, pk_fp4_t>,
BDataType,
ADataType>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_t> ||
sizeof(BDataType) < sizeof(ADataType),
ADataType,
BDataType>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t MWave = Problem::MWave;
static constexpr index_t NWave = Problem::NWave;
static constexpr index_t MPerXdl = Problem::MPerXdl;
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr bool EightWave = (MWave * NWave == 8);
static constexpr index_t BlockedXDLN_PerWarp =
EightWave ? kNPerBlock / NWave / NPerXdl : Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
CDElementwise elfunc_;
CK_TILE_DEVICE CShuffleEpilogue(CDElementwise elfunc = CDElementwise{}) : elfunc_(elfunc) {};
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "CShuffleEpilogue",
concat('x', MWave, NWave),
concat('x', MPerXdl, NPerXdl, KPerXdl),
VectorSizeC,
isCTransposed ? "CTransposed" : "CNotTransposed");
// clang-format on
}
/**
* @brief Get the vector store size for C tensor.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
{
if constexpr(FixedVectorSize)
{
return VectorSizeC;
}
constexpr index_t max_vector_size = 16;
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
/**
* @brief Get the vector store size for Di tensor.
*
* @return The vector store size for Di tensor.
*/
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
{
constexpr index_t max_vector_size = 16;
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else
{
static_assert(false, "Unsupported DLayout!");
}
return max_vector_size / sizeof(DiDataType);
}
/**
* @brief Shuffle tile configuration parameters check and aligment
*
* @details Return tuple(1, 1) if shuffle_tile values are too large for SMEM.
*/
template <index_t m_shuffle_tile, index_t n_shuffle_tile>
CK_TILE_HOST_DEVICE static constexpr auto AlignShuffleTileWithSmem()
{
constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile;
constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
constexpr auto shuffle_tile =
m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
? std::make_tuple(1, 1)
: std::make_tuple(m_shuffle_tile, n_shuffle_tile);
return shuffle_tile;
}
/**
* @brief Shuffle tile configuration parameters
*
* @details These parameters control the number of XDL tiles processed per wave in each shuffle
* iteration:
* - NumMXdlPerWavePerShuffle: Number of XDL tiles in M dimension processed per wave
* - NumNXdlPerWavePerShuffle: Number of XDL tiles in N dimension processed per wave
*/
static constexpr auto shuffle_tile_tuple = [] {
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
if constexpr(elem_per_thread <= GetVectorSizeC())
{
return std::make_tuple(1, 1);
}
else
{
constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC();
static_assert(elem_per_thread % GetVectorSizeC() == 0);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
"kMPerBlock must be divisible by MPerXdl*MWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return AlignShuffleTileWithSmem<min(num_xdl_shuffles,
kMPerBlock / (MPerXdl * MWave)),
1>();
}
else
{
static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
(kNPerBlock % num_xdl_shuffles == 0),
"kNPerBlock must be divisible by NPerXdl*NWave and "
"num_xdl_shuffles for CShuffleEpilogue");
return AlignShuffleTileWithSmem<1,
min(num_xdl_shuffles,
kNPerBlock / (NPerXdl * NWave))>();
}
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle =
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
static constexpr auto MNPerIterationShuffle = [] {
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
return std::make_tuple(MPerXdl * MWave, NPerXdl * NWave);
else
return std::make_tuple(m_val, n_val);
}();
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
using WG = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
AccDataType,
MPerXdl,
NPerXdl,
KPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
constexpr auto DataTypeSize = sizeof(ODataType);
constexpr index_t VectorLen = GetVectorSizeC();
constexpr index_t banks = get_n_lds_banks();
constexpr index_t BytesPerBank = 4;
// N is contiguous dimension
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
constexpr index_t MLdsLayerRequired =
banks * BytesPerBank / NPerIterationShuffle / DataTypeSize;
constexpr auto MLdsLayer = max(1, MLdsLayerRequired);
constexpr index_t BaseStrideElems = NPerIterationShuffle * MLdsLayer;
static_assert((BaseStrideElems * DataTypeSize) % BytesPerBank == 0,
"LDS row stride must be 4B-aligned for bank-word padding logic");
// calculate how many elements to pad to avoid bank conflict
#if defined(__gfx950__)
constexpr index_t ElemsPer4B = BytesPerBank / ck_tile::gcd(BytesPerBank, DataTypeSize);
constexpr auto ToWords = [](index_t elems) constexpr {
return (elems * DataTypeSize) / BytesPerBank;
};
constexpr index_t BaseWords = ToWords(BaseStrideElems);
constexpr index_t PadWords = ((BaseWords % 2) == 0) ? 1 : 0;
constexpr auto PaddingAmount = PadWords * ElemsPer4B;
#else
constexpr auto PaddingAmount = 0;
#endif
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle / MLdsLayer>{},
number<NPerIterationShuffle / VectorLen * MLdsLayer>{},
number<VectorLen>{}),
make_tuple(number<NPerIterationShuffle * MLdsLayer + PaddingAmount>{},
number<VectorLen>{},
number<1>{}),
number<VectorLen>{},
number<1>{});
constexpr auto lds_block_desc_1 = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<MPerIterationShuffle / MLdsLayer>{}),
make_unmerge_transform(make_tuple(
number<MLdsLayer>{}, number<NPerIterationShuffle / VectorLen>{})),
make_pass_through_transform(number<VectorLen>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(
number<MPerIterationShuffle / MLdsLayer>{}, number<MLdsLayer>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<NPerIterationShuffle / VectorLen>{}, number<VectorLen>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc;
}
// M is contiguous dimension
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t NLdsLayerRequired =
get_n_lds_banks() * BytesPerBank / MPerIterationShuffle / DataTypeSize;
constexpr auto NLdsLayer = max(1, NLdsLayerRequired);
constexpr index_t BaseStrideElems = MPerIterationShuffle * NLdsLayer;
static_assert((BaseStrideElems * DataTypeSize) % BytesPerBank == 0,
"LDS row stride must be 4B-aligned for bank-word padding logic");
#if defined(__gfx950__)
constexpr index_t ElemsPer4B = BytesPerBank / ck_tile::gcd(BytesPerBank, DataTypeSize);
constexpr auto ToWords = [](index_t elems) constexpr {
return (elems * DataTypeSize) / BytesPerBank;
};
constexpr index_t BaseWords = ToWords(BaseStrideElems);
constexpr index_t PadWords = ((BaseWords % 2) == 0) ? 1 : 0;
constexpr auto PaddingAmount = PadWords * ElemsPer4B;
#else
constexpr auto PaddingAmount = 0;
#endif
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NPerIterationShuffle / NLdsLayer>{},
number<MPerIterationShuffle / VectorLen * NLdsLayer>{},
number<VectorLen>{}),
make_tuple(number<MPerIterationShuffle * NLdsLayer + PaddingAmount>{},
number<VectorLen>{},
number<1>{}),
number<VectorLen>{},
number<1>{});
constexpr auto lds_block_desc_1 = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NPerIterationShuffle / NLdsLayer>{}),
make_unmerge_transform(make_tuple(
number<NLdsLayer>{}, number<MPerIterationShuffle / VectorLen>{})),
make_pass_through_transform(number<VectorLen>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(
number<NPerIterationShuffle / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<MPerIterationShuffle / VectorLen>{}, number<VectorLen>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc;
}
else
{
static_assert(false, "Unsupported ELayout!");
}
}
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding = [] {
if constexpr(BlockedXDLN_PerWarp == 1)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else
{
#if defined(__gfx950__)
constexpr auto is_950 = true;
#else
constexpr auto is_950 = false;
#endif
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
// this branch is for original a16w4
if constexpr(is_950 || is_any_of<ADataType, pk_int4_t, pk_fp4_t>::value ||
is_any_of<BDataType, pk_int4_t, pk_fp4_t>::value)
{
if constexpr(EightWave)
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<2, 1>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
}
else
{
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, BlockedXDLN_PerWarp, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 2>>,
sequence<1, 2, 2>,
sequence<0, 0, 1>>{};
}
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
return block_dstr_encoding;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
return lds_block_desc.get_element_space_size() * sizeof(ODataType);
}
template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
CK_TILE_DEVICE void
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
{
// Check if scales are EmptyScale first (no scaling needed)
if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
{
// No scaling needed - this is a no-op
}
// Check if scales are scalar AccDataType
else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
std::is_same_v<ScaleN, AccDataType>)
{
// Handle scalar scales
const AccDataType scale_m = scale_m_window;
const AccDataType scale_n = scale_n_window;
tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
lds_tile);
}
// Otherwise, assume they are tile windows that can be loaded
else
{
// Load tiles
const auto scale_m_tile = load_tile(scale_m_window);
const auto scale_n_tile = load_tile(scale_n_window);
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
tile_elementwise_inout(
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
// Move scale windows
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(number<iAccess>{});
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
}
}
}
template <index_t iAccess, typename OAccTile, typename LdsTile>
CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
{
constexpr auto idx_y_start = SFC::get_index(number<iAccess>{});
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter * NumMXdlPerWavePerShuffle, nIter * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
}
template <typename LdsTile, typename InLdsWindow>
CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
{
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
}
template <typename DramWindows, typename COutTensor>
CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
{
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
number<NumDTensor>{}));
tile_elementwise_inout_unpack(elfunc_, c_ds_tiles);
}
template <typename OutDramWindow, typename COutTensor>
CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
const COutTensor& c_out_tensor)
{
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
}
/**
* @brief Move both the output and D tensors windows for the next access.
*/
template <index_t iAccess, typename OutDramWindow, typename DDramWindows>
CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
{
constexpr index_t num_access = SFC::get_num_of_access();
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(number<iAccess>{});
// move the output dram window
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
// move windows for each of the D matrices (inputs for element-wise)
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx], {step.at(number<0>{}), step.at(number<1>{})});
});
}
}
// TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
struct EmptyScale
{
};
template <typename, typename = void>
struct ScaleDataType
{
using DataType = float;
};
template <typename T>
struct ScaleDataType<T, std::void_t<typename T::DataType>>
{
using DataType = typename T::DataType;
};
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* /* p_smem */,
const ScaleM& scale_m = {},
const ScaleN& scale_n = {})
{
static constexpr int RowsPerLane = CWarpTensor::get_thread_buffer_size();
static_assert(MPerXdl % RowsPerLane == 0,
"CShuffle (permuteN): MPerXdl must be divisible by per-lane row count.");
constexpr int kM0 = MWave;
constexpr int kM2 = RowsPerLane;
constexpr int kM1 = MPerXdl / kM2;
constexpr int kN0 = NWave;
constexpr int kN1 = NPerXdl;
constexpr int kN2 = NRepeat;
using IntrThreadShuffleEncode =
tile_distribution_encoding<sequence<>,
tuple<sequence<kM0, kM1, kM2>, sequence<kN0, kN1, kN2>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>;
constexpr auto dram_tile_distribution =
make_static_tile_distribution(IntrThreadShuffleEncode{});
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto shuffle_acc = make_static_distributed_tensor<AccDataType>(dram_tile_distribution);
auto c_out_tensor = make_static_distributed_tensor<ODataType>(dram_tile_distribution);
// Optional scales (must share the same distribution to match per-thread indexing)
constexpr bool has_scales =
!std::is_same<ScaleM, EmptyScale>::value && !std::is_same<ScaleN, EmptyScale>::value;
constexpr bool has_scalar_scales =
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
// Tiles to hold row/col scales when present
using SMType = typename ScaleDataType<ScaleM>::DataType;
using SNType = typename ScaleDataType<ScaleN>::DataType;
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
// Build windows only if non-scalar scales are provided
auto scale_m_window = [&]() {
if constexpr(has_scales && !has_scalar_scales)
{
return make_tile_window(scale_m, dram_tile_distribution);
}
else
{
return EmptyScale{};
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scales && !has_scalar_scales)
{
return make_tile_window(scale_n, dram_tile_distribution);
}
else
{
return EmptyScale{};
}
}();
static_for<0, MRepeat, 1>{}([&](auto mIter) {
// Slice accumulators for this M repeat into the permuted layout
shuffle_acc.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
// If non-scalar scales provided, load them with identical distribution
if constexpr(has_scales && !has_scalar_scales)
{
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
}
// Pack 4 “rows per lane” as you already do
static_for<0, NRepeat, 1>{}([&](auto n_idx) {
// source indices in shuffle_acc: (n_idx * product(Y) + row)
const index_t plane = c_warp_y_lengths.product();
// local lambda to fuse scale (if present) and convert
static_for<0, kM2, 1>{}([&](auto m_lane) {
const int src = n_idx * plane + m_lane; // source row in this N-plane
const int dst = n_idx + m_lane * NRepeat; // permuted N layout in output
AccDataType v = shuffle_acc.get_thread_buffer()[src];
if constexpr(has_scalar_scales)
{
v = static_cast<AccDataType>(v * scale_m * scale_n);
}
else if constexpr(has_scales && !has_scalar_scales)
{
const auto sm = static_cast<float>(sm_tile.get_thread_buffer()[dst]);
const auto sn = static_cast<float>(sn_tile.get_thread_buffer()[dst]);
v = static_cast<AccDataType>(v * sm * sn);
}
c_out_tensor.get_thread_buffer()[dst] = type_convert<ODataType>(v);
});
});
// store/update
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
else
{
update_tile(out_dram_window, c_out_tensor);
}
// advance output (and any D-tensors) by one MPerXdl*MWave chunk
move_tile_window(out_dram_window, {number<MPerXdl * MWave>{}, number<0>{}});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx], {number<MPerXdl * MWave>{}, number<0>{}});
});
});
}
template <typename ODramWindow,
typename OAccTile,
typename DsDramWindows,
typename ScaleM = EmptyScale,
typename ScaleN = EmptyScale,
int EnablePermuateN_ = TiledMMAPermuteN,
std::enable_if_t<!EnablePermuateN_, int> = 0>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem,
const ScaleM& scale_m = {},
const ScaleN& scale_n = {})
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0},
LdsTileDistr);
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
{0, 0});
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution =
TileEncodingPattern::make_2d_static_tile_distribution();
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr bool has_scales =
!std::is_same_v<ScaleM, EmptyScale> && !std::is_same_v<ScaleN, EmptyScale>;
constexpr bool has_scalar_scales =
std::is_same_v<ScaleM, AccDataType> && std::is_same_v<ScaleN, AccDataType>;
auto scale_m_window = [&]() {
if constexpr(has_scalar_scales)
{
return scale_m;
}
else if constexpr(has_scales)
{
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
}
else
{
return EmptyScale{};
}
}();
auto scale_n_window = [&]() {
if constexpr(has_scalar_scales)
{
return scale_n;
}
else if constexpr(has_scales)
{
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
}
else
{
return EmptyScale{};
}
}();
static_for<0, num_access, 1>{}([&](auto iAccess) {
block_sync_lds();
slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
if constexpr(has_scales)
{
scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
}
cast_lds_tile(lds_tile, in_lds_window);
block_sync_lds();
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
apply_d_tensors(d_dram_windows, c_out_tensor);
store_to_dram(out_dram_window, c_out_tensor);
move_windows<iAccess>(out_dram_window, d_dram_windows);
});
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,91 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "default_2d_epilogue.hpp"
#include "dynamic_quant_epilogue.hpp"
namespace ck_tile {
// User can reuse DynamicQuantEpilogueTraits with this epilogue
template <bool kPadM_,
bool kPadN_,
bool UseSmoothInputScale_,
bool UseRawStore_ = true,
bool UseMax3_ = false>
using Default2DAndDynamicQuantEpilogueTraits =
DynamicQuantEpilogueTraits<kPadM_, kPadN_, UseSmoothInputScale_, UseRawStore_, UseMax3_>;
// This epilogue just store out a M*N matrix, row major
template <typename AccDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
typename ODataType_,
typename UnquantYDataType_,
typename BlockShape_,
typename Traits_>
struct Default2DAndDynamicQuantEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using UnquantYDataType = remove_cvref_t<UnquantYDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
using Traits = remove_cvref_t<Traits_>;
};
template <typename Problem_, typename Policy_ = void>
struct Default2DAndDynamicQuantEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using UnquantYDataType = remove_cvref_t<typename Problem::UnquantYDataType>;
static constexpr bool kPadM = Problem::Traits::kPadM;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
using Default2DProblem =
Default2DEpilogueProblem<AccDataType, UnquantYDataType, kPadM, kPadN, UseRawStore>;
using Default2D = Default2DEpilogue<Default2DProblem>;
using DynamicQuant = DynamicQuantEpilogue<Problem>;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(Default2D::GetSmemSize(), DynamicQuant::GetSmemSize());
}
template <typename ODramWindowTmpD,
typename ODramWindowTmpQ,
typename SmoothScaleWindow,
typename YScaleWindow,
typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp,
ODramWindowTmpQ& o_quant_dram_window_tmp,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem);
DynamicQuant{}(o_quant_dram_window_tmp, sm_scale_window_, y_scale_window, o_acc_tile, smem);
}
template <typename ODramWindowTmpD,
typename ODramWindowTmpQ,
typename YScaleWindow,
typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmpD& o_direct_dram_window_tmp,
ODramWindowTmpQ& o_quant_dram_window_tmp,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
Default2D{}(o_direct_dram_window_tmp, o_acc_tile, smem);
DynamicQuant{}(o_quant_dram_window_tmp, y_scale_window, o_acc_tile, smem);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,291 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
// this epilogue just store out a M*N matrix, row major
template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool UseRawStore_ = true>
struct Default2DEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
static constexpr index_t NumDTensor = 0;
};
template <typename AsDataType_,
typename BsDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
typename DsLayout_,
typename CLayout_,
typename CDElementwise_,
index_t kM_,
index_t kN_,
bool kPadM_,
bool kPadN_,
index_t kMPerXdl_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_,
bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
{
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using CLayout = remove_cvref_t<CLayout_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
using DsLayout = remove_cvref_t<DsLayout_>;
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
};
template <typename Problem_, typename Policy_ = void>
struct Default2DEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* = nullptr) const
{
constexpr bool is_partition_index =
std::is_convertible_v<decltype(ds_dram_windows),
decltype(get_partition_index(
o_acc_tile.get_tile_distribution()))>;
const auto storeOrUpdateTile = [&](const auto& o_tile) {
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
// FIXME?
// if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
// memory_operation_enum::set)
if constexpr(true)
{
if constexpr(is_partition_index)
{
store_tile_raw(o_dram_window_tmp,
cast_tile<ODataType>(o_tile),
/*partition_index=*/ds_dram_windows);
}
else
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
}
}
else
{
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
}
buffer_store_fence();
}
else
{
// FIXME?
// if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
// memory_operation_enum::set)
if constexpr(true)
{
if constexpr(is_partition_index)
{
store_tile(o_dram_window_tmp,
cast_tile<ODataType>(o_tile),
/*partition_index=*/ds_dram_windows);
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
}
}
else
{
if constexpr(is_partition_index)
{
update_tile(o_dram_window_tmp,
cast_tile<ODataType>(o_tile),
/*partition_index=*/ds_dram_windows);
}
else
{
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
}
}
}
};
if constexpr(!std::is_same_v<DsDramWindows, std::nullptr_t> && !is_partition_index &&
Problem::NumDTensor >= 1)
{
using elementwise_result_t = decltype(load_tile(
make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(),
make_tuple(Problem::kMPerBlock, Problem::kNPerBlock),
ds_dram_windows[number<0>{}].get_window_origin(),
o_acc_tile.get_tile_distribution())));
elementwise_result_t elementwise_result;
const auto d_tensor_tuple = generate_tuple(
[&](auto idx) {
const auto d_tile_window =
make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution());
return load_tile(d_tile_window);
},
number<Problem::NumDTensor>{});
const auto c_d_tuple = concat_tuple_of_reference(
tie(elementwise_result, o_acc_tile),
generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; },
number<Problem::NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple);
storeOrUpdateTile(elementwise_result);
}
else
{
storeOrUpdateTile(o_acc_tile);
}
}
};
template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
using Problem = remove_cvref_t<Problem_>;
using AsDataType = remove_cvref_t<typename Problem::AsDataType>;
using BsDataType = remove_cvref_t<typename Problem::BsDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
remove_cvref_t<AsDataType>,
remove_cvref_t<tuple<AsDataType>>>;
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
remove_cvref_t<BsDataType>,
remove_cvref_t<tuple<BsDataType>>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
using WG = WarpGemmDispatcher<ADataType,
BTypeToUse,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return (WG::WarpGemmAttribute::Impl::kCNLane *
WG::WarpGemmAttribute::Impl::kBNBlock) /
WG::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has just a single item in Mdim
return (WG::WarpGemmAttribute::Impl::kCNLane *
WG::WarpGemmAttribute::Impl::kAMBlock) /
WG::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number<I> index)
{
return GetVectorSizeC();
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,212 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
namespace ck_tile {
template <bool kPadM_,
bool kPadN_,
bool UseSmoothInputScale_,
bool UseRawStore_ = true,
bool UseMax3_ = false>
struct DynamicQuantEpilogueTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseSmoothInputScale = UseSmoothInputScale_;
static constexpr bool UseRawStore = UseRawStore_;
static constexpr bool UseMax3 = UseMax3_;
};
// this epilogue just store out a M*N matrix, row major
template <typename AccDataType_,
typename SmoothScaleDataType_,
typename YScaleDataType_,
typename ODataType_,
typename BlockShape_,
typename Traits_>
struct DynamicQuantEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using SmoothScaleDataType = remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = remove_cvref_t<YScaleDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using BlockShape = remove_cvref_t<BlockShape_>; // can consum generic 2d shape
using Traits = remove_cvref_t<Traits_>;
};
// TODO: we should put descriptor creation function into policy
template <typename Problem_, typename Policy_ = void>
struct DynamicQuantEpilogue
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using BlockShape = remove_cvref_t<typename Problem::BlockShape>;
static constexpr bool kPadM = Problem::Traits::kPadM;
static constexpr bool kPadN = Problem::Traits::kPadN;
static constexpr bool UseRawStore = Problem::Traits::UseRawStore;
static constexpr bool UseMax3 = Problem::Traits::UseMax3;
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
{
using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
return BlockReduce2d<P_>{};
}
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
{
using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
return BlockReduce2dSync<P_>{};
}
CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
{
using P_ = BlockReduce2dProblem<AccDataType, AccDataType, BlockShape>;
return BlockReduce2dCrossWarpSync<P_>{};
}
CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution()
{
using S = BlockShape;
#if 0
// don't remove this
// Note that if we set encoding purposely like this, you will result in compile fail
// TODO: sm_scale create local-scratch to accept arbitrary acc input (with same length)
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<0, 1, 1>,
sequence<0, 0, 3>>{});
#else
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 1>, sequence<1, 2>>,
sequence<1, 1>,
sequence<0, 3>>{});
#endif
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
return reduce_crosswarp_sync.GetSmemSize();
}
template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_dram_window_tmp,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
auto reduce = GetBlockReduce2d();
auto reduce_sync = GetBlockReduce2dSync();
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
auto o_acc_tmp = o_acc_tile;
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
auto row_absmax = [&]() {
constexpr auto y_size_per_row =
OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
number<1>{});
if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
{
// fast max3+abs implementation
const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
float rtn;
asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
: "=v"(rtn)
: "v"(acc_), "v"(v_0_), "v"(v_1_));
return rtn;
};
return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
}
else
{
return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_absmax);
}
}();
reduce_sync(row_absmax, f_absmax);
reduce_crosswarp_sync(row_absmax, smem, f_absmax);
// here y_scale is Acc TYpe, need convert to YScale type later
auto y_scale = tile_elementwise_in(
[&](const auto& v_) {
return v_ / type_convert<AccDataType>(numeric<ODataType>::max());
},
row_absmax);
store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));
sweep_tile(o_acc_tmp, [&](auto idx) {
constexpr auto row_id = make_tuple(idx[number<0>{}]);
o_acc_tmp(idx) = o_acc_tmp[idx] / y_scale(row_id);
});
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
}
}
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
// Smooth Dynamic Quant
template <typename ODramWindowTmp,
typename SmoothScaleWindow,
typename YScaleWindow,
typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const SmoothScaleWindow& sm_scale_window_,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
const auto sm_scale_window =
make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
auto sm_scale = load_tile(sm_scale_window);
auto o_acc_tmp = o_acc_tile;
sweep_tile(o_acc_tmp, [&](auto idx) {
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
const auto xs_ = type_convert<AccDataType>(sm_scale[j_idx]);
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
});
Impl(o_dram_window_tmp, y_scale_window, o_acc_tmp, smem);
}
// Dynamic Quant
template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
YScaleWindow& y_scale_window,
const OAccTile& o_acc_tile,
void* smem)
{
Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem);
}
};
} // namespace ck_tile