Shuffle fix for gfx950 (#3491)

* solve compiler issue

* solve the gfx950 mfma shuffle regression

* refactor jenkinsfile to handle arch name better

* [CK TILE] set divisor to count of thread along k dimension

* fix the compiler error

* solve degradation

* Finish the multiplies fix

* fix the scales

* solve compilation error

* solve the composes

* solve the error of tile sweeper

* fix the test and example

* fix for gfx950

---------

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: Cong Ma <congma13@amd.com>
This commit is contained in:
Thomas Ning
2026-01-14 01:21:29 +08:00
committed by GitHub
parent 9908a87c31
commit 00c46785a8
33 changed files with 161 additions and 152 deletions

View File

@@ -564,7 +564,7 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
using UpperIndex = multi_index<1>;
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
decltype(make_tuple(container_reduce(LowLengths{}, multiplies<>{}, number<1>{})));
using LowLengthsMagicDivisor = decltype(generate_tuple(
lambda_merge_generate_MagicDivision_calculate_magic_divisor<LowLengths>{},
@@ -584,7 +584,7 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
low_lengths_magic_divisor_{generate_tuple(
[&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); },
number<NDimLow>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, I1))}
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies<>{}, I1))}
{
static_assert(LowerIndex::size() == NDimLow, "wrong!");
}
@@ -707,10 +707,10 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
using UpperIndex = multi_index<1>;
using LowLengthsScan =
decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number<1>{}));
decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies<>{}, number<1>{}));
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
decltype(make_tuple(container_reduce(LowLengths{}, multiplies<>{}, number<1>{})));
LowLengths low_lengths_;
LowLengthsScan low_lengths_scan_;
@@ -721,8 +721,8 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths)
: low_lengths_{low_lengths},
low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, multiplies{}, number<1>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, number<1>{}))}
container_reverse_exclusive_scan(low_lengths, multiplies<>{}, number<1>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies<>{}, number<1>{}))}
{
static_assert(LowerIndex::size() == NDimLow, "wrong!");
}
@@ -832,7 +832,7 @@ struct unmerge : public base_transform<1, UpLengths::size()>
using UpperIndex = multi_index<NDimUp>;
using UpLengthsScan =
decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number<1>{}));
decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies<>{}, number<1>{}));
UpLengths up_lengths_;
UpLengthsScan up_lengths_scan_;
@@ -841,7 +841,8 @@ struct unmerge : public base_transform<1, UpLengths::size()>
CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths)
: up_lengths_{up_lengths},
up_lengths_scan_{container_reverse_exclusive_scan(up_lengths, multiplies{}, number<1>{})}
up_lengths_scan_{
container_reverse_exclusive_scan(up_lengths, multiplies<>{}, number<1>{})}
{
}

View File

@@ -19,7 +19,7 @@ template <typename TensorLengths,
struct space_filling_curve
{
static constexpr index_t TensorSize =
reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{});
reduce_on_sequence(TensorLengths{}, multiplies<>{}, number<1>{});
static_assert(0 < TensorSize,
"space_filling_curve should be used to access a non-empty tensor");
@@ -28,7 +28,7 @@ struct space_filling_curve
using Index = multi_index<nDim>;
static constexpr index_t ScalarPerVector =
reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{});
reduce_on_sequence(ScalarsPerAccess{}, multiplies<>{}, number<1>{});
static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
static constexpr auto dim_access_order = DimAccessOrder{};
@@ -49,7 +49,7 @@ struct space_filling_curve
static_assert(TensorLengths{} % ScalarsPerAccess{} ==
typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});
return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector;
return reduce_on_sequence(TensorLengths{}, multiplies<>{}, number<1>{}) / ScalarPerVector;
}
template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
@@ -94,7 +94,7 @@ struct space_filling_curve
#else
constexpr auto access_strides =
container_reverse_exclusive_scan(ordered_access_lengths, multiplies{}, number<1>{});
container_reverse_exclusive_scan(ordered_access_lengths, multiplies<>{}, number<1>{});
constexpr auto idx_1d = number<AccessIdx1d>{};
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the

View File

@@ -1237,10 +1237,11 @@ constexpr auto reverse_slice_sequence(Seq,
{
static_assert(Seq::size() == Mask::size());
static_assert(SliceSize != 0, "slice size zero is invalid");
static_assert(container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies{}, 1) %
SliceSize ==
0,
"slice size can't evenly divide input sizes");
static_assert(
container_reduce(pick_sequence_elements_by_mask(Seq{}, Mask{}), multiplies<>{}, 1) %
SliceSize ==
0,
"slice size can't evenly divide input sizes");
using sliced_type =
impl::reverse_slice_sequence_impl<Seq,
Mask,

View File

@@ -42,7 +42,7 @@ struct scales
};
template <typename Scale>
CK_TILE_HOST_DEVICE_EXTERN scales(Scale) -> scales<Scale>;
scales(Scale) -> scales<Scale>;
template <typename Left = void, typename Right = Left>
struct plus
@@ -65,8 +65,6 @@ struct plus<void, void>
}
};
CK_TILE_HOST_DEVICE_EXTERN plus() -> plus<void, void>;
template <typename Left = void, typename Right = Left>
struct minus
{
@@ -88,8 +86,6 @@ struct minus<void, void>
}
};
CK_TILE_HOST_DEVICE_EXTERN minus() -> minus<void, void>;
template <typename Left = void, typename Right = Left>
struct multiplies
{
@@ -111,8 +107,6 @@ struct multiplies<void, void>
}
};
CK_TILE_HOST_DEVICE_EXTERN multiplies() -> multiplies<void, void>;
template <typename T>
struct maximize
{
@@ -341,8 +335,6 @@ struct equal<void, void>
}
};
CK_TILE_HOST_DEVICE_EXTERN equal() -> equal<void, void>;
template <>
struct equal<float, float>
{
@@ -382,8 +374,6 @@ struct less<void, void>
}
};
CK_TILE_HOST_DEVICE_EXTERN less() -> less<void, void>;
template <typename Left = void, typename Right = Left>
struct less_equal
{
@@ -405,8 +395,6 @@ struct less_equal<void, void>
}
};
CK_TILE_HOST_DEVICE_EXTERN less_equal() -> less_equal<void, void>;
template <>
struct less_equal<float, float>
{

View File

@@ -434,7 +434,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
"the vector length is not the same!");
constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
constexpr index_t num_of_access =
reduce_on_sequence(y_in_lengths, multiplies{}, number<1>{}) / vecLoadSize;
reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{}) / vecLoadSize;
using DataVec = array<typename BottomTensorView_::DataType, vecLoadSize>;
static_for<0, num_of_access, 1>{}([&](auto iAccess) {

View File

@@ -229,7 +229,7 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
template <typename YLengths, index_t XUnpacks>
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number<XUnpacks>)
{
constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{});
constexpr auto y_size = reduce_on_sequence(YLengths{}, multiplies<>{}, number<1>{});
constexpr auto y_packs = number<XUnpacks>{};
static_assert(y_size % y_packs == 0);
constexpr auto y_slice_size = y_size / y_packs;

View File

@@ -297,12 +297,12 @@ struct tile_sweeper
// partial deduction is not allowed
// template <typename T, typename F, typename U>
// CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
// tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
// deduction guide
template <typename T,
typename F,
typename U = typename uniform_sequence_gen<T::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper<T, F, U>;
tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper<T, F, U>;
} // namespace ck_tile

View File

@@ -76,7 +76,7 @@ struct tensor_adaptor
number<ndim_top_>{});
// TODO: make container_reduce support tuple of number and index_t
return container_reduce(lengths, multiplies{}, number<1>{});
return container_reduce(lengths, multiplies<>{}, number<1>{});
}
template <index_t IDimHidden>

View File

@@ -382,7 +382,7 @@ make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
const auto element_space_size = container_reduce(lengths, multiplies<>{}, long_number<1>{});
constexpr index_t first_dim_length = []() {
if constexpr(is_constant_v<remove_cvref_t<decltype(element_space_size)>>)
@@ -428,7 +428,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offs
number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
const auto desc_0 = [&]() {
const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});
const auto element_space_size = container_reduce(lengths, multiplies<>{}, long_number<1>{});
const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));
@@ -491,8 +491,12 @@ make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align ali
}
else
{
return container_reduce(
lengths, multiplies{}, number<stride_n_minus_2>{}, i + I1, number<N - 1>{}, I1);
return container_reduce(lengths,
multiplies<>{},
number<stride_n_minus_2>{},
i + I1,
number<N - 1>{},
I1);
}
},
number<N>{});

View File

@@ -113,7 +113,7 @@ struct tile_distribution
return generate_tuple(
[&](auto i) {
constexpr index_t x_length =
container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1);
container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies<>{}, 1);
return number<x_length>{};
},
@@ -583,8 +583,8 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
if constexpr(x_slice_ends[i] == -1)
{
// -1 means till the end
constexpr auto x_length_ =
container_reduce(typename Encoding::HsLengthss{}[i], multiplies{}, number<1>{});
constexpr auto x_length_ = container_reduce(
typename Encoding::HsLengthss{}[i], multiplies<>{}, number<1>{});
return x_length_;
}
else

View File

@@ -277,7 +277,7 @@ struct tile_window_linear
{
constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
constexpr auto is_pure_linear_tensor =
reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
reduce_on_sequence(LinearBottomDims{}, multiplies<>{}, number<1>{});
if constexpr(is_pure_linear_tensor)
{
// this case usually is a LDS window, everything is known at compile tile.

View File

@@ -69,9 +69,9 @@ struct static_uford_one_shot_impl
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
{
constexpr auto r_lens_stride =
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies{}, number<1>{});
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies<>{}, number<1>{});
constexpr auto r_upks_stride =
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies{}, number<1>{});
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies<>{}, number<1>{});
constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
constexpr index_t pack_len = RamainUnpacks::front();
@@ -127,7 +127,7 @@ template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_uford
{
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{});
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies<>{}, number<1>{});
CK_TILE_HOST_DEVICE constexpr static_uford()
{
@@ -142,7 +142,7 @@ struct static_uford
{
using L_ = decltype(Lengths{} / Unpacks{});
return reduce_on_sequence(L_{}, multiplies{}, number<1>{});
return reduce_on_sequence(L_{}, multiplies<>{}, number<1>{});
}
// F signature: F(sequence<...> multi_id...)

View File

@@ -47,8 +47,11 @@ struct composes<F>
F f_;
};
template <typename... Ts>
CK_TILE_HOST_DEVICE_EXTERN composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
template <class... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_composes(Ts&&... ts)
{
return composes<remove_cvref_t<Ts>...>{std::forward<Ts>(ts)...};
}
template <typename SaturateType>
struct saturates

View File

@@ -65,7 +65,7 @@ inline bool is_gfx12_supported()
return get_device_name() == "gfx1200" || get_device_name() == "gfx1201";
}
inline bool is_load_tr_supported()
inline bool is_gfx95_supported()
{
    // Check if the device is gfx950 (the arch that supports load transpose).
    return get_device_name() == "gfx950";

View File

@@ -2,6 +2,7 @@
// SPDX-License-Identifier: MIT
#pragma once
#include "device_prop.hpp"
#include <stdexcept>
namespace ck_tile {
@@ -98,7 +99,7 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmConfig)
else
{
assert(is_wave32() == false);
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
}
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Warp_Tile,
gemmConfig.N_Warp_Tile,
@@ -167,7 +168,7 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmC
else
{
assert(is_wave32() == false);
divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4;
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
}
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,

View File

@@ -24,7 +24,7 @@ struct ElementWiseShape
static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kVectorM * kThreadPerWarpM);
static constexpr index_t kBlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
};
} // namespace ck_tile

View File

@@ -19,7 +19,8 @@ struct TileFlatmmShape
static constexpr auto idxN = number<1>{};
static constexpr auto idxK = number<2>{};
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t NumWarps =
reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
static constexpr index_t kM = BlockTile::at(idxM);
static constexpr index_t kN = BlockTile::at(idxN);

View File

@@ -1193,39 +1193,40 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
auto o_acc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
return make_composes(saturates<ck_tile::fp8_t>{},
scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
else
return ck_tile::scales{scale_o};
return scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
}();
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(scale_p)>>{scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,
variant_params.sm_scale,
variant,
variant_params,
block_indices,
smem_ptr,
page_idx,
stride_k_for_pipeline,
stride_v_for_pipeline,
kargs.batch_stride_k,
kargs.batch_stride_v,
dropout);
}
else
{

View File

@@ -1538,10 +1538,11 @@ struct FmhaFwdKernel
auto o_acc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return ck_tile::composes(ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales{scale_o});
return make_composes(
ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
else
return ck_tile::scales{scale_o};
return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
}();
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
@@ -1553,9 +1554,10 @@ struct FmhaFwdKernel
identity{}, // bias_element_func
randval_dram_window,
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{scale_p}, // p_compute_element_func
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(scale_p)>>{
scale_p}, // p_compute_element_func
o_acc_element_func, // o_acc_element_func
mask,
position_encoding,

View File

@@ -1325,30 +1325,32 @@ struct FmhaFwdPagedKVKernel
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window_lengths,
k_page_block_navigator,
identity{}, // k_element_func
v_dram_window_lengths,
v_page_block_navigator,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func
k_dram_window_lengths,
k_page_block_navigator,
identity{}, // k_element_func
v_dram_window_lengths,
v_page_block_navigator,
identity{}, // v_element_func
bias_dram_window,
identity{}, // bias_element_func
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(kargs.scale_p)>>{
kargs.scale_p}, // p_compute_element_func
make_composes(saturates<fp8_t>{},
scales<remove_cvref_t<decltype(kargs.scale_o)>>{
kargs.scale_o}), // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
variant,
variant_params,
block_indices,
kv_l2p_offset,
smem_ptr);
}
else
{

View File

@@ -457,14 +457,15 @@ struct FmhaFwdSplitKVCombineKernel
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
return FmhaPipeline{}(
lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
smem_ptr);
return FmhaPipeline{}(lse_acc_dram_window,
o_acc_dram_window,
lse_dram_window,
identity{}, // lse_element_func
make_composes(saturates<fp8_t>{},
scales<remove_cvref_t<decltype(kargs.scale_o)>>{
kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
smem_ptr);
}
else
{

View File

@@ -1069,10 +1069,11 @@ struct FmhaFwdSplitKVKernel
bias_dram_window,
identity{}, // bias_element_func
lse_acc_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales{kargs.scale_p}, // p_compute_element_func
identity{}, // o_acc_element_func
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(kargs.scale_p)>>{
kargs.scale_p}, // p_compute_element_func
identity{}, // o_acc_element_func
kargs.num_splits,
i_split_,
mask,

View File

@@ -42,9 +42,9 @@ struct TileFmhaShape
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
static constexpr index_t NumGemm0Warps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
static constexpr index_t NumGemm1Warps =
reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{});
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
@@ -95,10 +95,10 @@ struct TileFmhaBwdShape
using Gemm4WarpTile = remove_cvref_t<Gemm4WarpTile_>;
static constexpr index_t NumWarps =
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
reduce_on_sequence(Gemm0BlockWarps{}, multiplies<>{}, number<1>{});
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}) &&
NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies{}, number<1>{}));
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies<>{}, number<1>{}) &&
NumWarps == reduce_on_sequence(Gemm4BlockWarps{}, multiplies<>{}, number<1>{}));
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen

View File

@@ -56,10 +56,10 @@ struct FusedMoeGemmShape
using WarpTile_1 = remove_cvref_t<WarpTile_1_>;
static constexpr index_t NumWarps =
reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
reduce_on_sequence(WarpPerBlock_0{}, multiplies<>{}, number<1>{});
// TODO: we don't support half warps aound to 1 warp here
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies<>{}, number<1>{}));
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});

View File

@@ -19,7 +19,8 @@ struct TileGemmShape
using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t NumWarps =
reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
static constexpr index_t kM = BlockTile::at(number<0>{});
static constexpr index_t kN = BlockTile::at(number<1>{});

View File

@@ -52,6 +52,6 @@ struct PoolShape
static constexpr index_t Repeat_N = Block_N * WarpSizeScaleFactor_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t BlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
};
} // namespace ck_tile

View File

@@ -345,7 +345,7 @@ struct BlockReduce2D
constexpr auto row_y_unpacks = [&]() {
constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
constexpr auto row_y_size =
reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
reduce_on_sequence(row_y_lengths, multiplies<>{}, number<1>{});
constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
static_assert(row_y_size % row_y_packs == 0);

View File

@@ -39,6 +39,6 @@ struct Reduce2dShape
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t BlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
};
} // namespace ck_tile

View File

@@ -96,7 +96,7 @@ struct TopkSoftmaxWarpPerRowPipeline
w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
}
};
tile_sweeper ts{w_, w_f};
tile_sweeper<decltype(w_), decltype(w_f)> ts{w_, w_f};
ts();
return w_;
#endif