mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
Shuffle fix for gfx950 (#3491)
* solve compiler issue
* solve the gfx950 mfma shuffle regression
* refactor jenkinsfile to handle arch name better
* [CK TILE] set divisor to count of thread along k dimension
* fix the compiler error
* solve degradation
* Finish the multiplies fix
* fix the scales
* solve compilation error
* solve the composes
* solve the error of tile sweeper
* fix the test and example
* fix for gfx950

---------

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: Cong Ma <congma13@amd.com>
This commit is contained in:
@@ -564,7 +564,7 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
|
||||
using UpperIndex = multi_index<1>;
|
||||
|
||||
using UpLengths =
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, multiplies<>{}, number<1>{})));
|
||||
|
||||
using LowLengthsMagicDivisor = decltype(generate_tuple(
|
||||
lambda_merge_generate_MagicDivision_calculate_magic_divisor<LowLengths>{},
|
||||
@@ -584,7 +584,7 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
|
||||
low_lengths_magic_divisor_{generate_tuple(
|
||||
[&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); },
|
||||
number<NDimLow>{})},
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, I1))}
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies<>{}, I1))}
|
||||
{
|
||||
static_assert(LowerIndex::size() == NDimLow, "wrong!");
|
||||
}
|
||||
@@ -707,10 +707,10 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
|
||||
using UpperIndex = multi_index<1>;
|
||||
|
||||
using LowLengthsScan =
|
||||
decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number<1>{}));
|
||||
decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies<>{}, number<1>{}));
|
||||
|
||||
using UpLengths =
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
|
||||
decltype(make_tuple(container_reduce(LowLengths{}, multiplies<>{}, number<1>{})));
|
||||
|
||||
LowLengths low_lengths_;
|
||||
LowLengthsScan low_lengths_scan_;
|
||||
@@ -721,8 +721,8 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
|
||||
CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths)
|
||||
: low_lengths_{low_lengths},
|
||||
low_lengths_scan_{
|
||||
container_reverse_exclusive_scan(low_lengths, multiplies{}, number<1>{})},
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, number<1>{}))}
|
||||
container_reverse_exclusive_scan(low_lengths, multiplies<>{}, number<1>{})},
|
||||
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies<>{}, number<1>{}))}
|
||||
{
|
||||
static_assert(LowerIndex::size() == NDimLow, "wrong!");
|
||||
}
|
||||
@@ -832,7 +832,7 @@ struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
using UpperIndex = multi_index<NDimUp>;
|
||||
|
||||
using UpLengthsScan =
|
||||
decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number<1>{}));
|
||||
decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies<>{}, number<1>{}));
|
||||
|
||||
UpLengths up_lengths_;
|
||||
UpLengthsScan up_lengths_scan_;
|
||||
@@ -841,7 +841,8 @@ struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths)
|
||||
: up_lengths_{up_lengths},
|
||||
up_lengths_scan_{container_reverse_exclusive_scan(up_lengths, multiplies{}, number<1>{})}
|
||||
up_lengths_scan_{
|
||||
container_reverse_exclusive_scan(up_lengths, multiplies<>{}, number<1>{})}
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ template <typename TensorLengths,
|
||||
struct space_filling_curve
|
||||
{
|
||||
static constexpr index_t TensorSize =
|
||||
reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(TensorLengths{}, multiplies<>{}, number<1>{});
|
||||
static_assert(0 < TensorSize,
|
||||
"space_filling_curve should be used to access a non-empty tensor");
|
||||
|
||||
@@ -28,7 +28,7 @@ struct space_filling_curve
|
||||
using Index = multi_index<nDim>;
|
||||
|
||||
static constexpr index_t ScalarPerVector =
|
||||
reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(ScalarsPerAccess{}, multiplies<>{}, number<1>{});
|
||||
|
||||
static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
|
||||
static constexpr auto dim_access_order = DimAccessOrder{};
|
||||
@@ -49,7 +49,7 @@ struct space_filling_curve
|
||||
static_assert(TensorLengths{} % ScalarsPerAccess{} ==
|
||||
typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});
|
||||
|
||||
return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector;
|
||||
return reduce_on_sequence(TensorLengths{}, multiplies<>{}, number<1>{}) / ScalarPerVector;
|
||||
}
|
||||
|
||||
template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
|
||||
@@ -94,7 +94,7 @@ struct space_filling_curve
|
||||
#else
|
||||
|
||||
constexpr auto access_strides =
|
||||
container_reverse_exclusive_scan(ordered_access_lengths, multiplies{}, number<1>{});
|
||||
container_reverse_exclusive_scan(ordered_access_lengths, multiplies<>{}, number<1>{});
|
||||
|
||||
constexpr auto idx_1d = number<AccessIdx1d>{};
|
||||
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
|
||||
|
||||
@@ -1237,10 +1237,11 @@ constexpr auto reverse_slice_sequence(Seq,
|
||||
{
|
||||
static_assert(Seq::size() == Mask::size());
|
||||
static_assert(SliceSize != 0, "slice size zero is invalid");
|
||||
static_assert(container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies{}, 1) %
|
||||
SliceSize ==
|
||||
0,
|
||||
"slice size can't evenly divide input sizes");
|
||||
static_assert(
|
||||
container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies<>{}, 1) %
|
||||
SliceSize ==
|
||||
0,
|
||||
"slice size can't evenly divide input sizes");
|
||||
using sliced_type =
|
||||
impl::reverse_slice_sequence_impl<Seq,
|
||||
Mask,
|
||||
|
||||
@@ -42,7 +42,7 @@ struct scales
|
||||
};
|
||||
|
||||
template <typename Scale>
|
||||
CK_TILE_HOST_DEVICE_EXTERN scales(Scale) -> scales<Scale>;
|
||||
scales(Scale) -> scales<Scale>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct plus
|
||||
@@ -65,8 +65,6 @@ struct plus<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE_EXTERN plus() -> plus<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct minus
|
||||
{
|
||||
@@ -88,8 +86,6 @@ struct minus<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE_EXTERN minus() -> minus<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct multiplies
|
||||
{
|
||||
@@ -111,8 +107,6 @@ struct multiplies<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE_EXTERN multiplies() -> multiplies<void, void>;
|
||||
|
||||
template <typename T>
|
||||
struct maximize
|
||||
{
|
||||
@@ -341,8 +335,6 @@ struct equal<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE_EXTERN equal() -> equal<void, void>;
|
||||
|
||||
template <>
|
||||
struct equal<float, float>
|
||||
{
|
||||
@@ -382,8 +374,6 @@ struct less<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE_EXTERN less() -> less<void, void>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct less_equal
|
||||
{
|
||||
@@ -405,8 +395,6 @@ struct less_equal<void, void>
|
||||
}
|
||||
};
|
||||
|
||||
CK_TILE_HOST_DEVICE_EXTERN less_equal() -> less_equal<void, void>;
|
||||
|
||||
template <>
|
||||
struct less_equal<float, float>
|
||||
{
|
||||
|
||||
@@ -434,7 +434,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
|
||||
"the vector length is not the same!");
|
||||
constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
|
||||
constexpr index_t num_of_access =
|
||||
reduce_on_sequence(y_in_lengths, multiplies{}, number<1>{}) / vecLoadSize;
|
||||
reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{}) / vecLoadSize;
|
||||
|
||||
using DataVec = array<typename BottomTensorView_::DataType, vecLoadSize>;
|
||||
static_for<0, num_of_access, 1>{}([&](auto iAccess) {
|
||||
|
||||
@@ -229,7 +229,7 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
|
||||
template <typename YLengths, index_t XUnpacks>
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number<XUnpacks>)
|
||||
{
|
||||
constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{});
|
||||
constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies<>{}, number<1>{});
|
||||
constexpr auto y_packs = number<XUnpacks>{};
|
||||
static_assert(y_size % y_packs == 0);
|
||||
constexpr auto y_slice_size = y_size / y_packs;
|
||||
|
||||
@@ -297,12 +297,12 @@ struct tile_sweeper
|
||||
|
||||
// partial deduction is not allowed
|
||||
// template <typename T, typename F, typename U>
|
||||
// CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
|
||||
// tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
|
||||
|
||||
// deduction guide
|
||||
template <typename T,
|
||||
typename F,
|
||||
typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
|
||||
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper<T, F, U>;
|
||||
tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper<T, F, U>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -76,7 +76,7 @@ struct tensor_adaptor
|
||||
number<ndim_top_>{});
|
||||
|
||||
// TODO: make container_reduce support tuple of number and index_t
|
||||
return container_reduce(lengths, multiplies{}, number<1>{});
|
||||
return container_reduce(lengths, multiplies<>{}, number<1>{});
|
||||
}
|
||||
|
||||
template <index_t IDimHidden>
|
||||
|
||||
@@ -382,7 +382,7 @@ make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
|
||||
|
||||
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
|
||||
|
||||
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
|
||||
const auto element_space_size = container_reduce(lengths, multiplies<>{}, long_number<1>{});
|
||||
|
||||
constexpr index_t first_dim_length = []() {
|
||||
if constexpr(is_constant_v<remove_cvref_t<decltype(element_space_size)>>)
|
||||
@@ -428,7 +428,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offs
|
||||
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
|
||||
{
|
||||
const auto desc_0 = [&]() {
|
||||
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
|
||||
const auto element_space_size = container_reduce(lengths, multiplies<>{}, long_number<1>{});
|
||||
|
||||
const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));
|
||||
|
||||
@@ -491,8 +491,12 @@ make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align ali
|
||||
}
|
||||
else
|
||||
{
|
||||
return container_reduce(
|
||||
lengths, multiplies{}, number<stride_n_minus_2>{}, i + I1, number<N - 1>{}, I1);
|
||||
return container_reduce(lengths,
|
||||
multiplies<>{},
|
||||
number<stride_n_minus_2>{},
|
||||
i + I1,
|
||||
number<N - 1>{},
|
||||
I1);
|
||||
}
|
||||
},
|
||||
number<N>{});
|
||||
|
||||
@@ -113,7 +113,7 @@ struct tile_distribution
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
constexpr index_t x_length =
|
||||
container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1);
|
||||
container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies<>{}, 1);
|
||||
|
||||
return number<x_length>{};
|
||||
},
|
||||
@@ -583,8 +583,8 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
|
||||
if constexpr(x_slice_ends[i] == -1)
|
||||
{
|
||||
// -1 means till the end
|
||||
constexpr auto x_length_ =
|
||||
container_reduce(typename Encoding::HsLengthss{}[i], multiplies{}, number<1>{});
|
||||
constexpr auto x_length_ = container_reduce(
|
||||
typename Encoding::HsLengthss{}[i], multiplies<>{}, number<1>{});
|
||||
return x_length_;
|
||||
}
|
||||
else
|
||||
|
||||
@@ -277,7 +277,7 @@ struct tile_window_linear
|
||||
{
|
||||
constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
|
||||
constexpr auto is_pure_linear_tensor =
|
||||
reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(LinearBottomDims{}, multiplies<>{}, number<1>{});
|
||||
if constexpr(is_pure_linear_tensor)
|
||||
{
|
||||
// this case usually is a LDS window, everything is known at compile time.
|
||||
|
||||
@@ -69,9 +69,9 @@ struct static_uford_one_shot_impl
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
|
||||
{
|
||||
constexpr auto r_lens_stride =
|
||||
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies{}, number<1>{});
|
||||
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies<>{}, number<1>{});
|
||||
constexpr auto r_upks_stride =
|
||||
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies{}, number<1>{});
|
||||
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies<>{}, number<1>{});
|
||||
|
||||
constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
|
||||
constexpr index_t pack_len = RamainUnpacks::front();
|
||||
@@ -127,7 +127,7 @@ template <class Lengths,
|
||||
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
|
||||
struct static_uford
|
||||
{
|
||||
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{});
|
||||
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies<>{}, number<1>{});
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr static_uford()
|
||||
{
|
||||
@@ -142,7 +142,7 @@ struct static_uford
|
||||
{
|
||||
using L_ = decltype(Lengths{} / Unpacks{});
|
||||
|
||||
return reduce_on_sequence(L_{}, multiplies{}, number<1>{});
|
||||
return reduce_on_sequence(L_{}, multiplies<>{}, number<1>{});
|
||||
}
|
||||
|
||||
// F signature: F(sequence<...> multi_id...)
|
||||
|
||||
@@ -47,8 +47,11 @@ struct composes<F>
|
||||
F f_;
|
||||
};
|
||||
|
||||
template <typename... Ts>
|
||||
CK_TILE_HOST_DEVICE_EXTERN composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
|
||||
// Factory helper that builds a composes<> object from its arguments.
// The template arguments of the result are the argument types with their
// cv/ref qualifiers stripped (remove_cvref_t), and the arguments themselves
// are perfect-forwarded into the brace-initialized object.
// NOTE(review): presumably intended to be called in place of relying on class
// template argument deduction for composes — confirm against the call sites.
template <class... Ts>
|
||||
CK_TILE_HOST_DEVICE constexpr auto make_composes(Ts&&... ts)
|
||||
{
|
||||
return composes<remove_cvref_t<Ts>...>{std::forward<Ts>(ts)...};
|
||||
}
|
||||
|
||||
template <typename SaturateType>
|
||||
struct saturates
|
||||
|
||||
@@ -65,7 +65,7 @@ inline bool is_gfx12_supported()
|
||||
return get_device_name() == "gfx1200" || get_device_name() == "gfx1201";
|
||||
}
|
||||
|
||||
inline bool is_load_tr_supported()
|
||||
inline bool is_gfx95_supported()
|
||||
{
|
||||
// Check if load transpose is supported.
|
||||
return get_device_name() == "gfx950";
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "device_prop.hpp"
|
||||
#include <stdexcept>
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -98,7 +99,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
|
||||
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
|
||||
gemmConfig.N_Warp_Tile,
|
||||
@@ -167,7 +168,7 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmC
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
|
||||
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
|
||||
gemmConfig.N_Warp,
|
||||
|
||||
@@ -24,7 +24,7 @@ struct ElementWiseShape
|
||||
static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kVectorM * kThreadPerWarpM);
|
||||
|
||||
static constexpr index_t kBlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -19,7 +19,8 @@ struct TileFlatmmShape
|
||||
static constexpr auto idxN = number<1>{};
|
||||
static constexpr auto idxK = number<2>{};
|
||||
|
||||
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
|
||||
static constexpr index_t kM = BlockTile::at(idxM);
|
||||
static constexpr index_t kN = BlockTile::at(idxN);
|
||||
|
||||
@@ -1193,39 +1193,40 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
|
||||
auto o_acc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o});
|
||||
return make_composes(saturates<ck_tile::fp8_t>{},
|
||||
scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
|
||||
else
|
||||
return ck_tile::scales{scale_o};
|
||||
return scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
|
||||
}();
|
||||
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
return FmhaPipeline{}(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(scale_p)>>{scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
page_idx,
|
||||
stride_k_for_pipeline,
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -1538,10 +1538,11 @@ struct FmhaFwdKernel
|
||||
|
||||
auto o_acc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
|
||||
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales{scale_o});
|
||||
return make_composes(
|
||||
ck_tile::saturates<ck_tile::fp8_t>{},
|
||||
ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
|
||||
else
|
||||
return ck_tile::scales{scale_o};
|
||||
return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
|
||||
}();
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
@@ -1553,9 +1554,10 @@ struct FmhaFwdKernel
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{scale_p}, // p_compute_element_func
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(scale_p)>>{
|
||||
scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
|
||||
@@ -1325,30 +1325,32 @@ struct FmhaFwdPagedKVKernel
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
return FmhaPipeline{}(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window_lengths,
|
||||
k_page_block_navigator,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window_lengths,
|
||||
v_page_block_navigator,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{kargs.scale_p}, // p_compute_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window_lengths,
|
||||
k_page_block_navigator,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window_lengths,
|
||||
v_page_block_navigator,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(kargs.scale_p)>>{
|
||||
kargs.scale_p}, // p_compute_element_func
|
||||
make_composes(saturates<fp8_t>{},
|
||||
scales<remove_cvref_t<decltype(kargs.scale_o)>>{
|
||||
kargs.scale_o}), // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -457,14 +457,15 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(kDoFp8StaticQuant)
|
||||
{
|
||||
return FmhaPipeline{}(
|
||||
lse_acc_dram_window,
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
smem_ptr);
|
||||
return FmhaPipeline{}(lse_acc_dram_window,
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
make_composes(saturates<fp8_t>{},
|
||||
scales<remove_cvref_t<decltype(kargs.scale_o)>>{
|
||||
kargs.scale_o}), // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -1069,10 +1069,11 @@ struct FmhaFwdSplitKVKernel
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
lse_acc_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales{kargs.scale_p}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(kargs.scale_p)>>{
|
||||
kargs.scale_p}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
i_split_,
|
||||
mask,
|
||||
|
||||
@@ -42,9 +42,9 @@ struct TileFmhaShape
|
||||
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
|
||||
|
||||
static constexpr index_t NumGemm0Warps =
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
static constexpr index_t NumGemm1Warps =
|
||||
reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
|
||||
|
||||
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
|
||||
@@ -95,10 +95,10 @@ struct TileFmhaBwdShape
|
||||
using Gemm4WarpTile = remove_cvref_t<Gemm4WarpTile_>;
|
||||
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
|
||||
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) &&
|
||||
NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{}));
|
||||
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{}) &&
|
||||
NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies<>{}, number<1>{}));
|
||||
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
|
||||
@@ -56,10 +56,10 @@ struct FusedMoeGemmShape
|
||||
using WarpTile_1 = remove_cvref_t<WarpTile_1_>;
|
||||
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(WarpPerBlock_0{}, multiplies<>{}, number<1>{});
|
||||
|
||||
// TODO: we don't support half warps aound to 1 warp here
|
||||
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
|
||||
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies<>{}, number<1>{}));
|
||||
|
||||
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
|
||||
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
|
||||
|
||||
@@ -19,7 +19,8 @@ struct TileGemmShape
|
||||
using BlockWarps = remove_cvref_t<BlockWarps_>;
|
||||
using WarpTile = remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
|
||||
static constexpr index_t kM = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kN = BlockTile::at(number<1>{});
|
||||
|
||||
@@ -52,6 +52,6 @@ struct PoolShape
|
||||
static constexpr index_t Repeat_N = Block_N * WarpSizeScaleFactor_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -345,7 +345,7 @@ struct BlockReduce2D
|
||||
constexpr auto row_y_unpacks = [&]() {
|
||||
constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
|
||||
constexpr auto row_y_size =
|
||||
reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
|
||||
reduce_on_sequence(row_y_lengths, multiplies<>{}, number<1>{});
|
||||
constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
|
||||
|
||||
static_assert(row_y_size % row_y_packs == 0);
|
||||
|
||||
@@ -39,6 +39,6 @@ struct Reduce2dShape
|
||||
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t BlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -96,7 +96,7 @@ struct TopkSoftmaxWarpPerRowPipeline
|
||||
w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
|
||||
}
|
||||
};
|
||||
tile_sweeper ts{w_, w_f};
|
||||
tile_sweeper<decltype(w_), decltype(w_f)> ts{w_, w_f};
|
||||
ts();
|
||||
return w_;
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user