mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Merge commit 'b03764ca5a917752845ddbb5da8886051a16d9be' into develop
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
@@ -50,10 +51,10 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
bool PadK = true,
|
||||
bool Preshuffle = false,
|
||||
bool TransposeC = false>
|
||||
bool invoke_streamk(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s,
|
||||
int num_cu,
|
||||
int occupancy)
|
||||
std::tuple<bool, ck_tile::index_t> invoke_streamk(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s,
|
||||
int num_cu,
|
||||
int occupancy)
|
||||
{
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
@@ -129,7 +130,7 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return false;
|
||||
return std::tuple{false, -1};
|
||||
}
|
||||
|
||||
dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner);
|
||||
@@ -138,7 +139,16 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grid_dims, block_dims, 0, kargs));
|
||||
|
||||
return true;
|
||||
ck_tile::index_t num_accumulations_per_tile =
|
||||
ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
|
||||
kargs.tile_partitioner.sk_num_blocks,
|
||||
// k_iters_per_big_block could be 1, which indicates that all blocks are
|
||||
// big and each does one iteration. Thus, we ensure the value passed in is at
|
||||
// least 1 to avoid division by zero errors.
|
||||
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
|
||||
kargs.tile_partitioner.k_iters_per_tile.get());
|
||||
|
||||
return std::tuple{true, num_accumulations_per_tile};
|
||||
};
|
||||
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
@@ -238,8 +248,11 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
reduction_strategy,
|
||||
num_sk_blocks};
|
||||
|
||||
if(!invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy))
|
||||
const auto [is_valid_instance, num_accumulations_per_tile] =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy);
|
||||
|
||||
if(!is_valid_instance)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping this test: The kernel cannot solve the problem\n";
|
||||
}
|
||||
@@ -256,7 +269,7 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, /*kbatch*/ 1, max_accumulated_value);
|
||||
K, num_accumulations_per_tile, max_accumulated_value);
|
||||
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
|
||||
Reference in New Issue
Block a user