Merge commit 'b03764ca5a917752845ddbb5da8886051a16d9be' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-17 17:11:18 +00:00
parent 99ccb97fad
commit f2f7a548cb
15 changed files with 172 additions and 80 deletions

View File

@@ -10,6 +10,7 @@
#include <tuple>
#include "ck_tile/host.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
@@ -50,10 +51,10 @@ class TestCkTileStreamK : public ::testing::Test
bool PadK = true,
bool Preshuffle = false,
bool TransposeC = false>
bool invoke_streamk(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s,
int num_cu,
int occupancy)
std::tuple<bool, ck_tile::index_t> invoke_streamk(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s,
int num_cu,
int occupancy)
{
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
@@ -129,7 +130,7 @@ class TestCkTileStreamK : public ::testing::Test
if(!Kernel::IsSupportedArgument(kargs))
{
return false;
return std::tuple{false, -1};
}
dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner);
@@ -138,7 +139,16 @@ class TestCkTileStreamK : public ::testing::Test
ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grid_dims, block_dims, 0, kargs));
return true;
ck_tile::index_t num_accumulations_per_tile =
ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
kargs.tile_partitioner.sk_num_blocks,
// k_iters_per_big_block could be 1, which indicates that all blocks are
// big and each does one iteration. Thus, we ensure the value passed in is at
// least 1 to avoid division by zero errors.
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
kargs.tile_partitioner.k_iters_per_tile.get());
return std::tuple{true, num_accumulations_per_tile};
};
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
@@ -238,8 +248,11 @@ class TestCkTileStreamK : public ::testing::Test
reduction_strategy,
num_sk_blocks};
if(!invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy))
const auto [is_valid_instance, num_accumulations_per_tile] =
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy);
if(!is_valid_instance)
{
GTEST_SKIP() << "Skipping this test: The kernel cannot solve the problem\n";
}
@@ -256,7 +269,7 @@ class TestCkTileStreamK : public ::testing::Test
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, /*kbatch*/ 1, max_accumulated_value);
K, num_accumulations_per_tile, max_accumulated_value);
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,