[CK_TILE] Fix cshuffle epilogue issue with IsLoadableTile (#2903)

* Fix issue with constexpr checks in scaling/cshuffle

* Remove IsLoadableTile

* Move amd_wave_read_first_lane before first usage
This commit is contained in:
Sami Remes
2025-09-24 09:08:18 +03:00
committed by GitHub
parent b159841a06
commit dcd33a6ecc
3 changed files with 58 additions and 19 deletions

View File

@@ -481,13 +481,10 @@ struct CShuffleEpilogue
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
// Build windows only if scales are provided
// Build windows only if non-scalar scales are provided
auto scale_m_window = [&]() {
if constexpr(has_scales && !has_scalar_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
"ScaleM must be a loadable tile");
return make_tile_window(scale_m, dram_tile_distribution);
}
else
@@ -498,9 +495,6 @@ struct CShuffleEpilogue
auto scale_n_window = [&]() {
if constexpr(has_scales && !has_scalar_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
"ScaleN must be a loadable tile");
return make_tile_window(scale_n, dram_tile_distribution);
}
else
@@ -515,8 +509,8 @@ struct CShuffleEpilogue
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
// If scales provided, load them with identical distribution
if constexpr(has_scales && IsLoadableTile<ScaleM> && IsLoadableTile<ScaleN>)
// If non-scalar scales provided, load them with identical distribution
if constexpr(has_scales && !has_scalar_scales)
{
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
@@ -535,7 +529,7 @@ struct CShuffleEpilogue
{
v = static_cast<AccDataType>(v * scale_m * scale_n);
}
else if constexpr(has_scales)
else if constexpr(has_scales && !has_scalar_scales)
{
// same linear index mapping on the permuted distribution
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
@@ -636,9 +630,6 @@ struct CShuffleEpilogue
}
else if constexpr(has_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
"ScaleM must be a loadable tile");
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
}
else
@@ -653,9 +644,6 @@ struct CShuffleEpilogue
}
else if constexpr(has_scales)
{
static_assert(
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
"ScaleN must be a loadable tile");
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
}
else