mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Fix CK Tile Stream-K BF16 Validation Errors (#3039)
Prior to this change, the number of accumulations passed into calculate_rtol_atol was always 1. In most cases this is not correct, because multiple workgroups contribute to the same macro tile in C. This change uses the function estimate_num_wgs_per_tile, which was extracted into a common file and generalized, to estimate the number of workgroups per macro tile. This estimate is passed into calculate_rtol_atol so that we get better relative and absolute tolerances.
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
@@ -50,10 +51,10 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
bool PadK = true,
|
||||
bool Preshuffle = false,
|
||||
bool TransposeC = false>
|
||||
bool invoke_streamk(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s,
|
||||
int num_cu,
|
||||
int occupancy)
|
||||
std::tuple<bool, ck_tile::index_t> invoke_streamk(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s,
|
||||
int num_cu,
|
||||
int occupancy)
|
||||
{
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
@@ -129,7 +130,7 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return false;
|
||||
return std::tuple{false, -1};
|
||||
}
|
||||
|
||||
dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner);
|
||||
@@ -138,7 +139,16 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grid_dims, block_dims, 0, kargs));
|
||||
|
||||
return true;
|
||||
ck_tile::index_t num_accumulations_per_tile =
|
||||
ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
|
||||
kargs.tile_partitioner.sk_num_blocks,
|
||||
// k_iters_per_big_block could be 1, which indicates that all blocks are
|
||||
// big and each does one iteration. Thus, we ensure the value passed in is at
|
||||
// least 1 to avoid division by zero errors.
|
||||
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
|
||||
kargs.tile_partitioner.k_iters_per_tile.get());
|
||||
|
||||
return std::tuple{true, num_accumulations_per_tile};
|
||||
};
|
||||
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
@@ -238,8 +248,11 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
reduction_strategy,
|
||||
num_sk_blocks};
|
||||
|
||||
if(!invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy))
|
||||
const auto [is_valid_instance, num_accumulations_per_tile] =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy);
|
||||
|
||||
if(!is_valid_instance)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping this test: The kernel cannot solve the problem\n";
|
||||
}
|
||||
@@ -256,7 +269,7 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, /*kbatch*/ 1, max_accumulated_value);
|
||||
K, num_accumulations_per_tile, max_accumulated_value);
|
||||
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
|
||||
Reference in New Issue
Block a user