mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Merge commit 'b03764ca5a917752845ddbb5da8886051a16d9be' into develop
This commit is contained in:
16
.github/workflows/pre-commit.yml
vendored
Normal file
16
.github/workflows/pre-commit.yml
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
# Runs the repository's pre-commit hooks in CI: on every pull request and on
# pushes to the develop branch.
name: pre-commit

on:
  pull_request:
  push:
    branches: [develop]

jobs:
  pre-commit:
    runs-on: ubuntu-latest
    steps:
      # Current major versions of the checkout/setup-python actions; the v3
      # majors run on the deprecated Node 16 runtime.
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: '3.12'
      # Executes the hooks configured in .pre-commit-config.yaml.
      - uses: pre-commit/action@v3.0.1
|
||||
@@ -32,9 +32,12 @@ repos:
|
||||
language: script
|
||||
types_or: [c++, text]
|
||||
verbose: true
|
||||
- id: run-remod-if-ck-tile-changed
|
||||
name: Run remod.py if ck_tile files changed
|
||||
entry: script/remod_for_ck_tile.sh
|
||||
language: script
|
||||
- id: remod-ck-tile
|
||||
name: Run ck_tile remod.py
|
||||
entry: python script/remod_for_ck_tile.py
|
||||
language: python
|
||||
files: '^(include|example)/ck_tile/.*$'
|
||||
additional_dependencies:
|
||||
- dos2unix
|
||||
- clang-format==18.1.3
|
||||
pass_filenames: false
|
||||
|
||||
@@ -2,29 +2,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
// Estimate the number of WGs contributing to the same macro tile in C
|
||||
template <ck_tile::StreamKReductionStrategy ReductionStrategy, typename TilePartitioner>
|
||||
int estimate_num_wgs_per_tile(const TilePartitioner& tile_partitioner)
|
||||
{
|
||||
// In the case of non-atomic reduction or DP only, there will always be 1 WG contributing to a
|
||||
// macro time in C
|
||||
int num_wgs_per_tile = 1;
|
||||
|
||||
// Otherwise, for atomics, multiple WGs may be contributing to the same macro tile in C
|
||||
if(tile_partitioner.sk_num_blocks > 0 &&
|
||||
ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Determine the number of iterations per WG for a given macro tile in C
|
||||
uint32_t k_iters_per_block = tile_partitioner.k_iters_per_big_block - 1;
|
||||
|
||||
// Estimate the number of WGs per macro tile
|
||||
num_wgs_per_tile = (tile_partitioner.k_iters_per_tile.get() / (k_iters_per_block)) +
|
||||
((tile_partitioner.k_iters_per_tile.get() % k_iters_per_block) != 0);
|
||||
}
|
||||
|
||||
return std::max(num_wgs_per_tile, 1);
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout)
|
||||
{
|
||||
@@ -65,7 +42,8 @@ template <typename GemmConfig,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s);
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s);
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
@@ -78,20 +56,21 @@ template <typename GemmConfig,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy,
|
||||
uint32_t num_sk_blocks)
|
||||
std::tuple<float, ck_tile::index_t>
|
||||
invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy,
|
||||
uint32_t num_sk_blocks)
|
||||
{
|
||||
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
@@ -105,7 +84,7 @@ std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
reduction_strategy,
|
||||
num_sk_blocks};
|
||||
|
||||
std::tuple<float, int> ave_time_and_batch;
|
||||
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
|
||||
|
||||
if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
@@ -16,7 +17,8 @@ template <typename GemmConfig,
|
||||
typename ELayout,
|
||||
typename CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s)
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -42,7 +44,7 @@ std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile:
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation) -> std::tuple<float, int> {
|
||||
const auto Run = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
@@ -113,7 +115,13 @@ std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile:
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
int num_wgs_per_tile = estimate_num_wgs_per_tile<ReductionStrategy>(kargs.tile_partitioner);
|
||||
ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
|
||||
kargs.tile_partitioner.sk_num_blocks,
|
||||
// k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
|
||||
// big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
|
||||
// avoid division by zero errors.
|
||||
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
|
||||
kargs.tile_partitioner.k_iters_per_tile.get());
|
||||
|
||||
return std::tuple{ave_time, num_wgs_per_tile};
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import pathlib
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
@@ -10,8 +11,12 @@ for p in sorted(Path("./").rglob("*")):
|
||||
|
||||
# formatting
|
||||
for x in all_files:
|
||||
subprocess.Popen(f"dos2unix -n {str(x)}", shell=True)
|
||||
cmd = f"clang-format-18 -style=file -i {str(x)}"
|
||||
subprocess.Popen(
|
||||
f"python -m dos2unix {str(x)} {str(x)}",
|
||||
shell=True,
|
||||
stdout=open(os.devnull, "wb"),
|
||||
)
|
||||
cmd = f"clang-format -style=file -i {str(x)}"
|
||||
# for xp in x.parents:
|
||||
# print(get_file_base(x))
|
||||
subprocess.Popen(cmd, shell=True)
|
||||
|
||||
@@ -73,7 +73,7 @@ struct Max
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::min();
|
||||
return numeric<T>::lowest();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
@@ -96,7 +96,7 @@ struct AbsMax
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::min();
|
||||
return numeric<T>::lowest();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
|
||||
@@ -6,6 +6,12 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
// GPU kernel to invalidate instruction cache for accurate benchmarking.
|
||||
// s_icache_inv: Asynchronously invalidates the L1 instruction cache on this compute unit,
|
||||
// forcing subsequent kernel runs to fetch instructions from HBM instead of cache.
|
||||
// 16x s_nop: Wait cycles (~16 cycles) to ensure cache invalidation completes before kernel
|
||||
// exits. Without these NOPs, the flush may not finish, leading to inconsistent
|
||||
// timing measurements where some instructions remain cached.
|
||||
static __global__ void flush_cache()
|
||||
{
|
||||
asm __volatile__("s_icache_inv \n\t"
|
||||
|
||||
@@ -9,6 +9,20 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// RotatingMemWrapper: Prevents GPU data cache reuse during kernel benchmarking.
|
||||
//
|
||||
// Purpose:
|
||||
// When benchmarking a kernel repeatedly with the same input buffers, the GPU L2 cache
|
||||
// will serve data from cache (hot) instead of HBM (cold), leading to artificially fast
|
||||
// timing measurements. This wrapper rotates through multiple copies of buffers at different
|
||||
// memory addresses to force cache misses.
|
||||
//
|
||||
// How it works:
|
||||
// Constructor: Creates rotating_count copies of matrices A and B in GPU memory
|
||||
// Next(): Switches pointers to the next buffer copy (cycles through all copies)
|
||||
// Destructor: Frees extra buffer copies and restores original pointers
|
||||
//
|
||||
// Combined with flush_icache(), this ensures realistic "cold cache" performance measurements.
|
||||
template <typename ADataType, typename BDataType>
|
||||
struct RotatingMemWrapper
|
||||
{
|
||||
@@ -24,15 +38,18 @@ struct RotatingMemWrapper
|
||||
size_a(size_a_),
|
||||
size_b(size_b_)
|
||||
{
|
||||
// Store original buffer pointers as first entry
|
||||
p_a_grids.push_back(a_ptr);
|
||||
p_b_grids.push_back(b_ptr);
|
||||
|
||||
// Create (rotating_count - 1) additional copies at different memory addresses
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
void* pADeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
const_cast<void*>(p_a_grids[0]),
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf), // target buffer
|
||||
const_cast<void*>(p_a_grids[0]), // source buffer
|
||||
size_a_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_a_grids.push_back(pADeviceBuf);
|
||||
@@ -41,19 +58,21 @@ struct RotatingMemWrapper
|
||||
{
|
||||
void* pBDeviceBuf;
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
const_cast<void*>(p_b_grids[0]),
|
||||
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf), // target buffer
|
||||
const_cast<void*>(p_b_grids[0]), // source buffer
|
||||
size_b_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_b_grids.push_back(pBDeviceBuf);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Rotate to the next buffer copy. Call this before each kernel run to use different
|
||||
// memory addresses, forcing the GPU to fetch data from HBM instead of cache.
|
||||
void Next()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
std::size_t idx = iter++ % rotating_count;
|
||||
std::size_t idx = iter++ % rotating_count; // Cycle through all buffer copies
|
||||
a_ptr = p_a_grids[idx];
|
||||
b_ptr = p_b_grids[idx];
|
||||
}
|
||||
@@ -63,15 +82,16 @@ struct RotatingMemWrapper
|
||||
std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
|
||||
<< ", rotating_count: " << rotating_count << "}" << std::endl;
|
||||
}
|
||||
// Cleanup: Free all extra buffer copies (keeping original) and restore original pointers
|
||||
~RotatingMemWrapper() noexcept
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
// restore ptr
|
||||
// Restore original buffer pointers
|
||||
a_ptr = p_a_grids[0];
|
||||
b_ptr = p_b_grids[0];
|
||||
|
||||
// free device mem
|
||||
// Free extra buffer copies (index 0 is the original, don't free it)
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
ck_tile::hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
@@ -94,7 +114,12 @@ inline void flush_icache()
|
||||
{
|
||||
hipDeviceProp_t deviceProps;
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
|
||||
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
|
||||
|
||||
// Over-provision blocks to ensure all CUs execute the flush instruction.
|
||||
// With imperfect scheduling, launching exactly 1 block per CU doesn't guarantee coverage.
|
||||
// 60x over-provisioning provides statistical certainty that every CU gets at least one block.
|
||||
constexpr int32_t blocks_per_cu = 60;
|
||||
int32_t gpu_block3 = deviceProps.multiProcessorCount * blocks_per_cu;
|
||||
|
||||
ck_tile::flush_cache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
|
||||
@@ -11,4 +11,33 @@ enum StreamKReductionStrategy : uint32_t
|
||||
Atomic = 0u,
|
||||
Reduction = 1u
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Estimates the number of Stream-K workgroups per macro tile in the C tensor.
|
||||
*
|
||||
* @param sk_ctas Number of Stream-K workgroups.
|
||||
* @param iters_per_sk_cta Number of iterations per Stream-K workgroup.
|
||||
* @param iters_per_tile Number of iterations per tile (i.e., the number of macro tiles in the K
|
||||
* dimension).
|
||||
* @return ck_tile::index_t An estimate of the number of workgroups per macro tile in the C tensor.
|
||||
* @note It is assumed that `iters_per_sk_cta` > 0.
|
||||
*/
|
||||
template <ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
ck_tile::index_t
|
||||
estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
|
||||
{
|
||||
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
|
||||
// writing final results to a given macro tile in C.
|
||||
int num_wgs_per_tile = 1;
|
||||
|
||||
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
|
||||
if(sk_ctas > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Estimate the number of workgroups per macro tile.
|
||||
num_wgs_per_tile =
|
||||
(iters_per_tile / iters_per_sk_cta) + ((iters_per_tile % iters_per_sk_cta) != 0);
|
||||
}
|
||||
|
||||
return std::max(num_wgs_per_tile, 1);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -33,9 +33,10 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
|
||||
@@ -86,8 +86,12 @@ class submodule_t:
|
||||
submodule = submodule_t()
|
||||
# formatting
|
||||
for x in all_files:
|
||||
subprocess.Popen(f"dos2unix -n {str(x)}", shell=True)
|
||||
cmd = f"clang-format-18 -style=file -i {str(x)}"
|
||||
subprocess.Popen(
|
||||
f"python -m dos2unix {str(x)} {str(x)}",
|
||||
shell=True,
|
||||
stdout=open(os.devnull, "wb"),
|
||||
)
|
||||
cmd = f"clang-format -style=file -i {str(x)}"
|
||||
# for xp in x.parents:
|
||||
# print(get_file_base(x))
|
||||
subprocess.Popen(cmd, shell=True)
|
||||
|
||||
@@ -13,9 +13,6 @@ echo "I: Creating and activating virtual environment for pre-commit..."
|
||||
python3 -m venv "$(dirname "$0")/../.venv"
|
||||
source "$(dirname "$0")/../.venv/bin/activate"
|
||||
|
||||
echo "I: Installing tools required for pre-commit checks..."
|
||||
run_and_check pip install dos2unix
|
||||
run_and_check pip install clang-format==18.1.3
|
||||
echo "I: Installing pre-commit in virtual environment..."
|
||||
run_and_check pip install pre-commit
|
||||
run_and_check pre-commit install
|
||||
|
||||
13
script/remod_for_ck_tile.py
Executable file
13
script/remod_for_ck_tile.py
Executable file
@@ -0,0 +1,13 @@
|
||||
"""Run the ck_tile ``remod.py`` scripts for the include and example trees.

Must be started from the repository root: both directories are resolved
relative to the current working directory. The process exits with a non-zero
status if any ``remod.py`` run fails, so callers can detect the failure.
"""
import os
import subprocess
import sys

root_dir = os.getcwd()
ck_tile_include = root_dir + "/include/ck_tile"
ck_tile_example = root_dir + "/example/ck_tile"

failed = False

# Run for include, then for example
for remod_dir in (ck_tile_include, ck_tile_example):
    # sys.executable keeps the child on the interpreter that runs this script;
    # cwd= runs remod.py from inside its own directory without changing ours.
    result = subprocess.run([sys.executable, "remod.py"], cwd=remod_dir)
    if result.returncode != 0:
        failed = True

sys.exit(1 if failed else 0)
|
||||
@@ -1,7 +0,0 @@
|
||||
#!/bin/bash
# Copyright © Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# Stop at the first failing command so a failed remod.py run is reported
# through the script's exit status instead of being masked by a later run.
set -e

# Run remod.py in both required locations (paths are relative to the
# directory the script is invoked from).
(cd include/ck_tile/ && python3 remod.py)
(cd example/ck_tile/ && python3 remod.py)
|
||||
@@ -10,6 +10,7 @@
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
@@ -50,10 +51,10 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
bool PadK = true,
|
||||
bool Preshuffle = false,
|
||||
bool TransposeC = false>
|
||||
bool invoke_streamk(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s,
|
||||
int num_cu,
|
||||
int occupancy)
|
||||
std::tuple<bool, ck_tile::index_t> invoke_streamk(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s,
|
||||
int num_cu,
|
||||
int occupancy)
|
||||
{
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
@@ -129,7 +130,7 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
return false;
|
||||
return std::tuple{false, -1};
|
||||
}
|
||||
|
||||
dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner);
|
||||
@@ -138,7 +139,16 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grid_dims, block_dims, 0, kargs));
|
||||
|
||||
return true;
|
||||
ck_tile::index_t num_accumulations_per_tile =
|
||||
ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
|
||||
kargs.tile_partitioner.sk_num_blocks,
|
||||
// k_iters_per_big_block could be 1, which indicates that all blocks are
|
||||
// big and each does one iteration. Thus, we ensure the value passed in is at
|
||||
// least 1 to avoid division by zero errors.
|
||||
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
|
||||
kargs.tile_partitioner.k_iters_per_tile.get());
|
||||
|
||||
return std::tuple{true, num_accumulations_per_tile};
|
||||
};
|
||||
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
@@ -238,8 +248,11 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
reduction_strategy,
|
||||
num_sk_blocks};
|
||||
|
||||
if(!invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy))
|
||||
const auto [is_valid_instance, num_accumulations_per_tile] =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy);
|
||||
|
||||
if(!is_valid_instance)
|
||||
{
|
||||
GTEST_SKIP() << "Skipping this test: The kernel cannot solve the problem\n";
|
||||
}
|
||||
@@ -256,7 +269,7 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, /*kbatch*/ 1, max_accumulated_value);
|
||||
K, num_accumulations_per_tile, max_accumulated_value);
|
||||
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
|
||||
Reference in New Issue
Block a user