Improve precision of CShuffle with scales or multi D

Two new template parameters are introduced:
 * CShuffleDataType allows applying multiple Ds before downcasting to
 ODataType (prevents unexpected precision loss and/or overflow);
 * CComputeDataType allows using scales with an int32 AccDataType (int8 GEMMs).
This commit is contained in:
Anton Gorenko
2026-06-16 11:41:20 +05:00
parent 1b649a8d4b
commit 335f80033b

View File

@@ -39,7 +39,9 @@ template <typename AsDataType_,
bool DoubleSmemBuffer_ = false,
typename AComputeDataType_ = void,
typename BComputeDataType_ = void,
bool TilesPacked_ = false>
bool TilesPacked_ = false,
typename CComputeDataType_ = AccDataType_,
typename CShuffleDataType_ = ODataType_>
struct CShuffleEpilogueProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
@@ -52,6 +54,8 @@ struct CShuffleEpilogueProblem
using CDElementwise = remove_cvref_t<CDElementwise_>;
using AComputeDataType = remove_cvref_t<AComputeDataType_>;
using BComputeDataType = remove_cvref_t<BComputeDataType_>;
using CComputeDataType = remove_cvref_t<CComputeDataType_>;
using CShuffleDataType = remove_cvref_t<CShuffleDataType_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
@@ -84,6 +88,8 @@ struct CShuffleEpilogue
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
using AComputeDataType = remove_cvref_t<typename Problem::AComputeDataType>;
using BComputeDataType = remove_cvref_t<typename Problem::BComputeDataType>;
using CComputeDataType = remove_cvref_t<typename Problem::CComputeDataType>;
using CShuffleDataType = remove_cvref_t<typename Problem::CShuffleDataType>;
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
@@ -236,7 +242,7 @@ struct CShuffleEpilogue
constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
constexpr auto shuffle_tile =
m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
m_val * n_val * sizeof(CShuffleDataType) > get_smem_capacity() || DoubleSmemBuffer
? std::make_tuple(1, 1)
: std::make_tuple(m_shuffle_tile, n_shuffle_tile);
@@ -317,7 +323,7 @@ struct CShuffleEpilogue
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
constexpr auto DataTypeSize = sizeof(ODataType);
constexpr auto DataTypeSize = sizeof(CShuffleDataType);
constexpr index_t VectorLen = GetVectorSizeC();
constexpr index_t banks = get_n_lds_banks();
@@ -550,7 +556,7 @@ struct CShuffleEpilogue
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
return lds_block_desc.get_element_space_size() * sizeof(ODataType);
return lds_block_desc.get_element_space_size() * sizeof(CShuffleDataType);
}
/// Number of block_sync_lds() calls in operator().
@@ -581,13 +587,14 @@ struct CShuffleEpilogue
}
template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
CK_TILE_DEVICE void
CK_TILE_DEVICE auto
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
{
// Check if scales are EmptyScale first (no scaling needed)
if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
{
// No scaling needed - this is a no-op
return lds_tile;
}
// Check if scales are scalar AccDataType
else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
@@ -598,6 +605,7 @@ struct CShuffleEpilogue
const AccDataType scale_n = scale_n_window;
tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
lds_tile);
return lds_tile;
}
// Otherwise, assume they are tile windows that can be loaded
else
@@ -606,9 +614,15 @@ struct CShuffleEpilogue
const auto scale_m_tile = load_tile(scale_m_window);
const auto scale_n_tile = load_tile(scale_n_window);
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
tile_elementwise_inout(
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
auto scaled_lds_tile =
make_static_distributed_tensor<CComputeDataType>(lds_tile.get_tile_distribution());
// Compute element-wise product i.e. scaled_lds_tile = lds_tile * scale_m * scale_n
tile_elementwise_inout(element_wise::MultiDMultiply{},
scaled_lds_tile,
lds_tile,
scale_m_tile,
scale_n_tile);
// Move scale windows
constexpr index_t num_access = SFC::get_num_of_access();
@@ -619,6 +633,8 @@ struct CShuffleEpilogue
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
}
return scaled_lds_tile;
}
}
@@ -641,14 +657,6 @@ struct CShuffleEpilogue
c_warp_y_lengths));
}
template <typename LdsTile, typename InLdsWindow>
CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
{
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
store_tile(in_lds_window, c_warptile_in_tensor_casted);
}
template <typename DramWindows, typename COutTensor>
CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
{
@@ -734,7 +742,7 @@ struct CShuffleEpilogue
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
static_cast<ODataType*>(p_smem), lds_block_desc);
static_cast<CShuffleDataType*>(p_smem), lds_block_desc);
auto in_lds_window = make_tile_window(
o_lds_block,
@@ -810,18 +818,16 @@ struct CShuffleEpilogue
block_sync_lds();
slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
if constexpr(has_scales)
{
const auto scaled_lds_tile =
scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
}
cast_lds_tile(lds_tile, in_lds_window);
store_tile(in_lds_window, cast_tile<CShuffleDataType>(scaled_lds_tile));
block_sync_lds();
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
apply_d_tensors(d_dram_windows, c_out_tensor);
store_to_dram(out_dram_window, c_out_tensor);
store_to_dram(out_dram_window, cast_tile<ODataType>(c_out_tensor));
move_windows<iAccess>(out_dram_window, d_dram_windows);
});
}