mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Fix CK Tile Stream-K BF16 Validation Errors (#3039)
Prior to this change, the number of accumulations passed into
calculate_rtol_atol was 1. That said, in most cases, this is not correct
when there are multiple workgroups contributing to the same macro tile
in C.
This change ensures uses the function estimate_num_wgs_per_tile, which
was extracted into a common file and generalized, to estimate the number
of workgroups per macro tile. This estimate is passed into
calculate_rtol_atol to ensure we get a better relative and absolute
tolerance.
[ROCm/composable_kernel commit: 352dee5225]
This commit is contained in:
@@ -2,29 +2,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
// Estimate the number of WGs contributing to the same macro tile in C
|
||||
template <ck_tile::StreamKReductionStrategy ReductionStrategy, typename TilePartitioner>
|
||||
int estimate_num_wgs_per_tile(const TilePartitioner& tile_partitioner)
|
||||
{
|
||||
// In the case of non-atomic reduction or DP only, there will always be 1 WG contributing to a
|
||||
// macro time in C
|
||||
int num_wgs_per_tile = 1;
|
||||
|
||||
// Otherwise, for atomics, multiple WGs may be contributing to the same macro tile in C
|
||||
if(tile_partitioner.sk_num_blocks > 0 &&
|
||||
ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Determine the number of iterations per WG for a given macro tile in C
|
||||
uint32_t k_iters_per_block = tile_partitioner.k_iters_per_big_block - 1;
|
||||
|
||||
// Estimate the number of WGs per macro tile
|
||||
num_wgs_per_tile = (tile_partitioner.k_iters_per_tile.get() / (k_iters_per_block)) +
|
||||
((tile_partitioner.k_iters_per_tile.get() % k_iters_per_block) != 0);
|
||||
}
|
||||
|
||||
return std::max(num_wgs_per_tile, 1);
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout)
|
||||
{
|
||||
@@ -65,7 +42,8 @@ template <typename GemmConfig,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s);
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s);
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
@@ -78,20 +56,21 @@ template <typename GemmConfig,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy,
|
||||
uint32_t num_sk_blocks)
|
||||
std::tuple<float, ck_tile::index_t>
|
||||
invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy,
|
||||
uint32_t num_sk_blocks)
|
||||
{
|
||||
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
@@ -105,7 +84,7 @@ std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
reduction_strategy,
|
||||
num_sk_blocks};
|
||||
|
||||
std::tuple<float, int> ave_time_and_batch;
|
||||
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
|
||||
|
||||
if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
@@ -16,7 +17,8 @@ template <typename GemmConfig,
|
||||
typename ELayout,
|
||||
typename CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s)
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -42,7 +44,7 @@ std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile:
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation) -> std::tuple<float, int> {
|
||||
const auto Run = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
@@ -113,7 +115,13 @@ std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile:
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
int num_wgs_per_tile = estimate_num_wgs_per_tile<ReductionStrategy>(kargs.tile_partitioner);
|
||||
ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
|
||||
kargs.tile_partitioner.sk_num_blocks,
|
||||
// k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
|
||||
// big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
|
||||
// avoid division by zero errors.
|
||||
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
|
||||
kargs.tile_partitioner.k_iters_per_tile.get());
|
||||
|
||||
return std::tuple{ave_time, num_wgs_per_tile};
|
||||
};
|
||||
|
||||
@@ -11,4 +11,33 @@ enum StreamKReductionStrategy : uint32_t
|
||||
Atomic = 0u,
|
||||
Reduction = 1u
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Estimates the number of Stream-K workgroups per macro tile in the C tensor.
|
||||
*
|
||||
* @param sk_ctas Number of Stream-K workgroups.
|
||||
* @param iters_per_sk_cta Number of iterations per Stream-K workgroup.
|
||||
* @param iters_per_tile Number of iterations per tile (i.e., the number of macro tiles in the K
|
||||
* dimension).
|
||||
* @return ck_tile::index_t An estimate of the number of workgroups per macro tile in the C tensor.
|
||||
* @note It is assumed that `iters_per_sk_cta` > 0.
|
||||
*/
|
||||
template <ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
ck_tile::index_t
|
||||
estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
|
||||
{
|
||||
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
|
||||
// writing final results to a given macro tile in C.
|
||||
int num_wgs_per_tile = 1;
|
||||
|
||||
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
|
||||
if(sk_ctas > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Estimate the number of workgroups per macro tile.
|
||||
num_wgs_per_tile =
|
||||
(iters_per_tile / iters_per_sk_cta) + ((iters_per_tile % iters_per_sk_cta) != 0);
|
||||
}
|
||||
|
||||
return std::max(num_wgs_per_tile, 1);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
@@ -50,10 +51,10 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
bool PadK = true,
|
||||
bool Preshuffle = false,
|
||||
bool TransposeC = false>
|
||||
bool invoke_streamk(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s,
|
||||
int num_cu,
|
||||
int occupancy)
|
||||
std::tuple<bool, ck_tile::index_t> invoke_streamk(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s,
|
||||
int num_cu,
|
||||
int occupancy)
|
||||
{
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
@@ -129,7 +130,7 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return false;
|
||||
return std::tuple{false, -1};
|
||||
}
|
||||
|
||||
dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner);
|
||||
@@ -138,7 +139,16 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grid_dims, block_dims, 0, kargs));
|
||||
|
||||
return true;
|
||||
ck_tile::index_t num_accumulations_per_tile =
|
||||
ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
|
||||
kargs.tile_partitioner.sk_num_blocks,
|
||||
// k_iters_per_big_block could be 1, which indicates that all blocks are
|
||||
// big and each does one iteration. Thus, we ensure the value passed in is at
|
||||
// least 1 to avoid division by zero errors.
|
||||
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
|
||||
kargs.tile_partitioner.k_iters_per_tile.get());
|
||||
|
||||
return std::tuple{true, num_accumulations_per_tile};
|
||||
};
|
||||
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
@@ -238,8 +248,11 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
reduction_strategy,
|
||||
num_sk_blocks};
|
||||
|
||||
if(!invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy))
|
||||
const auto [is_valid_instance, num_accumulations_per_tile] =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy);
|
||||
|
||||
if(!is_valid_instance)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping this test: The kernel cannot solve the problem\n";
|
||||
}
|
||||
@@ -256,7 +269,7 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, /*kbatch*/ 1, max_accumulated_value);
|
||||
K, num_accumulations_per_tile, max_accumulated_value);
|
||||
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
|
||||
Reference in New Issue
Block a user