mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Fix CK Tile DP + 2 Tile Stream-K Validation Errors (#3269)
When there are multiple workgroups contributing to a tile, when using atomics, there may be round off error in cases where the accumulator type is not the same as the C type. To compute an error tolerance for test validation, the Stream-K Tile Partitioner has a function called estimate_num_wgs_per_tile to estimate the number of workgroups per tile. That said, this function only provides an estimate. In some cases for DP+2TSK, the function returns 1 rather than the more accurate value of 2. Thus, this change updates the estimate_num_wgs_per_tile function to explicitely return the value of 2 in cases for DP+2TSK to ensure that we have a better error tolerance to avoid test failures due to round-off error.
This commit is contained in:
@@ -219,17 +219,27 @@ CK_TILE_HOST index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_num_wgs_per_tile()
|
||||
const noexcept
|
||||
{
|
||||
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
|
||||
// writing final results to a given macro tile in C.
|
||||
// In the case of non-atomic reduction or data-parallel (DP) only, there will always be 1
|
||||
// workgroup writing final results to a given macro tile in C.
|
||||
int num_wgs_per_tile = 1;
|
||||
|
||||
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
|
||||
if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
|
||||
// Estimate the number of workgroups per macro tile.
|
||||
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
|
||||
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
|
||||
// If we have DP and SK tiles, this is DP+2TSK which guarantees at most 2 workgroups per
|
||||
// tile. We only need to check that dp_tiles is greater than zero since we know we have SK
|
||||
// workgroups.
|
||||
if(dp_tiles_ > 0)
|
||||
{
|
||||
num_wgs_per_tile = 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
|
||||
// Estimate the number of workgroups per macro tile.
|
||||
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
|
||||
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
|
||||
}
|
||||
}
|
||||
|
||||
return std::max(num_wgs_per_tile, 1);
|
||||
|
||||
Reference in New Issue
Block a user