mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[CK_TILE] Fix cshuffle epilogue issue with IsLoadableTile (#2903)
* Fix issue with constexpr checks in scaling/cshuffle * Remove IsLoadableTile * Move amd_wave_read_first_lane before first usage
This commit is contained in:
@@ -481,13 +481,10 @@ struct CShuffleEpilogue
|
||||
auto sm_tile = make_static_distributed_tensor<SMType>(dram_tile_distribution);
|
||||
auto sn_tile = make_static_distributed_tensor<SNType>(dram_tile_distribution);
|
||||
|
||||
// Build windows only if scales are provided
|
||||
// Build windows only if non-scalar scales are provided
|
||||
auto scale_m_window = [&]() {
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
|
||||
"ScaleM must be a loadable tile");
|
||||
return make_tile_window(scale_m, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
@@ -498,9 +495,6 @@ struct CShuffleEpilogue
|
||||
auto scale_n_window = [&]() {
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
|
||||
"ScaleN must be a loadable tile");
|
||||
return make_tile_window(scale_n, dram_tile_distribution);
|
||||
}
|
||||
else
|
||||
@@ -515,8 +509,8 @@ struct CShuffleEpilogue
|
||||
merge_sequences(sequence<mIter, 0>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths));
|
||||
|
||||
// If scales provided, load them with identical distribution
|
||||
if constexpr(has_scales && IsLoadableTile<ScaleM> && IsLoadableTile<ScaleN>)
|
||||
// If non-scalar scales provided, load them with identical distribution
|
||||
if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
sm_tile = load_tile(scale_m_window); // row scales in permuted layout
|
||||
sn_tile = load_tile(scale_n_window); // col scales in permuted layout
|
||||
@@ -535,7 +529,7 @@ struct CShuffleEpilogue
|
||||
{
|
||||
v = static_cast<AccDataType>(v * scale_m * scale_n);
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
else if constexpr(has_scales && !has_scalar_scales)
|
||||
{
|
||||
// same linear index mapping on the permuted distribution
|
||||
const auto s_m = static_cast<float>(sm_tile.get_thread_buffer()[out_idx]);
|
||||
@@ -636,9 +630,6 @@ struct CShuffleEpilogue
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_m, dram_tile_distribution))>,
|
||||
"ScaleM must be a loadable tile");
|
||||
return make_tile_window(scale_m, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
@@ -653,9 +644,6 @@ struct CShuffleEpilogue
|
||||
}
|
||||
else if constexpr(has_scales)
|
||||
{
|
||||
static_assert(
|
||||
IsLoadableTile<decltype(make_tile_window(scale_n, dram_tile_distribution))>,
|
||||
"ScaleN must be a loadable tile");
|
||||
return make_tile_window(scale_n, lds_tile.get_tile_distribution());
|
||||
}
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user