mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
Improve precision of CShuffle with scales or multi D
Two new template parameters are introduced: * CShuffleDataType allows multiple Ds to be applied before downcasting to ODataType (prevents unexpected precision loss and/or overflow); * CComputeDataType allows scales to be used with an int32 AccDataType (int8 GEMMs).
This commit is contained in:
@@ -39,7 +39,9 @@ template <typename AsDataType_,
|
||||
bool DoubleSmemBuffer_ = false,
|
||||
typename AComputeDataType_ = void,
|
||||
typename BComputeDataType_ = void,
|
||||
bool TilesPacked_ = false>
|
||||
bool TilesPacked_ = false,
|
||||
typename CComputeDataType_ = AccDataType_,
|
||||
typename CShuffleDataType_ = ODataType_>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
@@ -52,6 +54,8 @@ struct CShuffleEpilogueProblem
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
using AComputeDataType = remove_cvref_t<AComputeDataType_>;
|
||||
using BComputeDataType = remove_cvref_t<BComputeDataType_>;
|
||||
using CComputeDataType = remove_cvref_t<CComputeDataType_>;
|
||||
using CShuffleDataType = remove_cvref_t<CShuffleDataType_>;
|
||||
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
@@ -84,6 +88,8 @@ struct CShuffleEpilogue
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
using AComputeDataType = remove_cvref_t<typename Problem::AComputeDataType>;
|
||||
using BComputeDataType = remove_cvref_t<typename Problem::BComputeDataType>;
|
||||
using CComputeDataType = remove_cvref_t<typename Problem::CComputeDataType>;
|
||||
using CShuffleDataType = remove_cvref_t<typename Problem::CShuffleDataType>;
|
||||
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
@@ -236,7 +242,7 @@ struct CShuffleEpilogue
|
||||
constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
|
||||
|
||||
constexpr auto shuffle_tile =
|
||||
m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
|
||||
m_val * n_val * sizeof(CShuffleDataType) > get_smem_capacity() || DoubleSmemBuffer
|
||||
? std::make_tuple(1, 1)
|
||||
: std::make_tuple(m_shuffle_tile, n_shuffle_tile);
|
||||
|
||||
@@ -317,7 +323,7 @@ struct CShuffleEpilogue
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
|
||||
{
|
||||
constexpr auto DataTypeSize = sizeof(ODataType);
|
||||
constexpr auto DataTypeSize = sizeof(CShuffleDataType);
|
||||
constexpr index_t VectorLen = GetVectorSizeC();
|
||||
constexpr index_t banks = get_n_lds_banks();
|
||||
|
||||
@@ -550,7 +556,7 @@ struct CShuffleEpilogue
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
return lds_block_desc.get_element_space_size() * sizeof(ODataType);
|
||||
return lds_block_desc.get_element_space_size() * sizeof(CShuffleDataType);
|
||||
}
|
||||
|
||||
/// Number of block_sync_lds() calls in operator().
|
||||
@@ -581,13 +587,14 @@ struct CShuffleEpilogue
|
||||
}
|
||||
|
||||
template <index_t iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
|
||||
CK_TILE_DEVICE void
|
||||
CK_TILE_DEVICE auto
|
||||
scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
|
||||
{
|
||||
// Check if scales are EmptyScale first (no scaling needed)
|
||||
if constexpr(std::is_same_v<ScaleM, EmptyScale> && std::is_same_v<ScaleN, EmptyScale>)
|
||||
{
|
||||
// No scaling needed - this is a no-op
|
||||
return lds_tile;
|
||||
}
|
||||
// Check if scales are scalar AccDataType
|
||||
else if constexpr(std::is_same_v<ScaleM, AccDataType> &&
|
||||
@@ -598,6 +605,7 @@ struct CShuffleEpilogue
|
||||
const AccDataType scale_n = scale_n_window;
|
||||
tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; },
|
||||
lds_tile);
|
||||
return lds_tile;
|
||||
}
|
||||
// Otherwise, assume they are tile windows that can be loaded
|
||||
else
|
||||
@@ -606,9 +614,15 @@ struct CShuffleEpilogue
|
||||
const auto scale_m_tile = load_tile(scale_m_window);
|
||||
const auto scale_n_tile = load_tile(scale_n_window);
|
||||
|
||||
// Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
|
||||
tile_elementwise_inout(
|
||||
element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
|
||||
auto scaled_lds_tile =
|
||||
make_static_distributed_tensor<CComputeDataType>(lds_tile.get_tile_distribution());
|
||||
|
||||
// Compute element-wise product i.e. scaled_lds_tile = lds_tile * scale_m * scale_n
|
||||
tile_elementwise_inout(element_wise::MultiDMultiply{},
|
||||
scaled_lds_tile,
|
||||
lds_tile,
|
||||
scale_m_tile,
|
||||
scale_n_tile);
|
||||
|
||||
// Move scale windows
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
@@ -619,6 +633,8 @@ struct CShuffleEpilogue
|
||||
move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
}
|
||||
|
||||
return scaled_lds_tile;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -641,14 +657,6 @@ struct CShuffleEpilogue
|
||||
c_warp_y_lengths));
|
||||
}
|
||||
|
||||
template <typename LdsTile, typename InLdsWindow>
|
||||
CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
|
||||
{
|
||||
const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
|
||||
|
||||
store_tile(in_lds_window, c_warptile_in_tensor_casted);
|
||||
}
|
||||
|
||||
template <typename DramWindows, typename COutTensor>
|
||||
CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
|
||||
{
|
||||
@@ -734,7 +742,7 @@ struct CShuffleEpilogue
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
static_cast<CShuffleDataType*>(p_smem), lds_block_desc);
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
@@ -810,18 +818,16 @@ struct CShuffleEpilogue
|
||||
block_sync_lds();
|
||||
slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
|
||||
|
||||
if constexpr(has_scales)
|
||||
{
|
||||
const auto scaled_lds_tile =
|
||||
scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
|
||||
}
|
||||
|
||||
cast_lds_tile(lds_tile, in_lds_window);
|
||||
store_tile(in_lds_window, cast_tile<CShuffleDataType>(scaled_lds_tile));
|
||||
block_sync_lds();
|
||||
|
||||
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
apply_d_tensors(d_dram_windows, c_out_tensor);
|
||||
store_to_dram(out_dram_window, c_out_tensor);
|
||||
store_to_dram(out_dram_window, cast_tile<ODataType>(c_out_tensor));
|
||||
move_windows<iAccess>(out_dram_window, d_dram_windows);
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user