mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Fix CK Tile Stream-K BF16 Validation Errors (#3039)
Prior to this change, the number of accumulations passed into calculate_rtol_atol was 1. That said, in most cases, this is not correct when there are multiple workgroups contributing to the same macro tile in C. This change ensures uses the function estimate_num_wgs_per_tile, which was extracted into a common file and generalized, to estimate the number of workgroups per macro tile. This estimate is passed into calculate_rtol_atol to ensure we get a better relative and absolute tolerance.
This commit is contained in:
@@ -11,4 +11,33 @@ enum StreamKReductionStrategy : uint32_t
|
||||
Atomic = 0u,
|
||||
Reduction = 1u
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Estimates the number of Stream-K workgroups per macro tile in the C tensor.
|
||||
*
|
||||
* @param sk_ctas Number of Stream-K workgroups.
|
||||
* @param iters_per_sk_cta Number of iterations per Stream-K workgroup.
|
||||
* @param iters_per_tile Number of iterations per tile (i.e., the number of macro tiles in the K
|
||||
* dimension).
|
||||
* @return ck_tile::index_t An estimate of the number of workgroups per macro tile in the C tensor.
|
||||
* @note It is assumed that `iters_per_sk_cta` > 0.
|
||||
*/
|
||||
template <ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
ck_tile::index_t
|
||||
estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
|
||||
{
|
||||
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
|
||||
// writing final results to a given macro tile in C.
|
||||
int num_wgs_per_tile = 1;
|
||||
|
||||
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
|
||||
if(sk_ctas > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Estimate the number of workgroups per macro tile.
|
||||
num_wgs_per_tile =
|
||||
(iters_per_tile / iters_per_sk_cta) + ((iters_per_tile % iters_per_sk_cta) != 0);
|
||||
}
|
||||
|
||||
return std::max(num_wgs_per_tile, 1);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user