Shuffle fix for gfx950 (#3491)

* solve compiler issue

* solve the gfx950 mfma shuffle regression

* refactor jenkinsfile to handle arch name better

* [CK TILE] set divisor to the count of threads along the k dimension

* fix the compiler error

* solve degradation

* Finish the multiplies fix

* fix the scales

* solve compilation error

* solve the composes

* solve the error of tile sweeper

* fix the test and example

* fix for gfx950

---------

Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: Cong Ma <congma13@amd.com>
This commit is contained in:
Thomas Ning
2026-01-14 01:21:29 +08:00
committed by GitHub
parent 9908a87c31
commit 00c46785a8
33 changed files with 161 additions and 152 deletions

View File

@@ -69,9 +69,9 @@ struct static_uford_one_shot_impl
CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentUnpackIds, number<current_acc>) const
{
constexpr auto r_lens_stride =
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies{}, number<1>{});
reverse_exclusive_scan_sequence(RemainLengths{}, multiplies<>{}, number<1>{});
constexpr auto r_upks_stride =
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies{}, number<1>{});
reverse_exclusive_scan_sequence(RamainUnpacks{}, multiplies<>{}, number<1>{});
constexpr index_t current_stride = r_lens_stride.front() / r_upks_stride.front();
constexpr index_t pack_len = RamainUnpacks::front();
@@ -127,7 +127,7 @@ template <class Lengths,
class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_uford
{
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies{}, number<1>{});
static constexpr index_t num_packs = reduce_on_sequence(Unpacks{}, multiplies<>{}, number<1>{});
CK_TILE_HOST_DEVICE constexpr static_uford()
{
@@ -142,7 +142,7 @@ struct static_uford
{
using L_ = decltype(Lengths{} / Unpacks{});
return reduce_on_sequence(L_{}, multiplies{}, number<1>{});
return reduce_on_sequence(L_{}, multiplies<>{}, number<1>{});
}
// F signature: F(sequence<...> multi_id...)

View File

@@ -47,8 +47,11 @@ struct composes<F>
F f_;
};
template <typename... Ts>
CK_TILE_HOST_DEVICE_EXTERN composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
// Factory for `composes`: builds a `composes<...>` whose template arguments are
// the argument types with cv-qualifiers and references stripped, and
// perfect-forwards every argument into its brace-initialization.
template <class... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_composes(Ts&&... ts)
{
    // Name the resulting type once so the return statement stays short.
    using composed_type = composes<remove_cvref_t<Ts>...>;
    return composed_type{std::forward<Ts>(ts)...};
}
template <typename SaturateType>
struct saturates