Merge commit 'b03764ca5a917752845ddbb5da8886051a16d9be' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-17 17:11:18 +00:00
parent 99ccb97fad
commit f2f7a548cb
15 changed files with 172 additions and 80 deletions

View File

@@ -2,29 +2,6 @@
// SPDX-License-Identifier: MIT
#pragma once
// Estimate the number of WGs (workgroups) contributing to the same macro tile in C.
//
// @tparam ReductionStrategy  Stream-K reduction strategy selected at compile time.
// @tparam TilePartitioner    Partitioner exposing sk_num_blocks, k_iters_per_big_block,
//                            and k_iters_per_tile (assumed per the ck_tile Stream-K
//                            partitioner interface — confirm against the kernel args).
// @param tile_partitioner    Partitioner describing the Stream-K work distribution.
// @return Estimated number of WGs per macro tile; always at least 1.
template <ck_tile::StreamKReductionStrategy ReductionStrategy, typename TilePartitioner>
int estimate_num_wgs_per_tile(const TilePartitioner& tile_partitioner)
{
    // In the case of non-atomic reduction or DP only, there will always be 1 WG contributing to a
    // macro tile in C.
    int num_wgs_per_tile = 1;
    // Otherwise, for atomics, multiple WGs may be contributing to the same macro tile in C.
    if(tile_partitioner.sk_num_blocks > 0 &&
       ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
    {
        // Determine the number of iterations per WG for a given macro tile in C.
        // k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
        // big and each does one iteration; subtracting 1 would then yield 0, so clamp to at
        // least 1 to avoid division (and modulo) by zero below.
        const uint32_t k_iters_per_block =
            std::max(tile_partitioner.k_iters_per_big_block - 1, 1u);
        // Estimate the number of WGs per macro tile: ceil(k_iters_per_tile / k_iters_per_block).
        num_wgs_per_tile = (tile_partitioner.k_iters_per_tile.get() / k_iters_per_block) +
                           ((tile_partitioner.k_iters_per_tile.get() % k_iters_per_block) != 0);
    }
    // Defensive floor: never report fewer than one contributing WG.
    return std::max(num_wgs_per_tile, 1);
}
template <typename Layout>
static constexpr inline auto is_row_major(Layout)
{
@@ -65,7 +42,8 @@ template <typename GemmConfig,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s);
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s);
template <typename GemmConfig,
typename ADataType,
@@ -78,20 +56,21 @@ template <typename GemmConfig,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
int n_warmup,
int n_repeat,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy,
uint32_t num_sk_blocks)
std::tuple<float, ck_tile::index_t>
invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
int n_warmup,
int n_repeat,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy,
uint32_t num_sk_blocks)
{
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
@@ -105,7 +84,7 @@ std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
reduction_strategy,
num_sk_blocks};
std::tuple<float, int> ave_time_and_batch;
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
{

View File

@@ -3,6 +3,7 @@
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
#include "ck_tile/ops/common.hpp"
template <typename GemmConfig,
typename ADataType,
@@ -16,7 +17,8 @@ template <typename GemmConfig,
typename ELayout,
typename CDEElementWise,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s)
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -42,7 +44,7 @@ std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile:
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
const auto Run = [&](const auto memory_operation) -> std::tuple<float, int> {
const auto Run = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
@@ -113,7 +115,13 @@ std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile:
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
int num_wgs_per_tile = estimate_num_wgs_per_tile<ReductionStrategy>(kargs.tile_partitioner);
ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
kargs.tile_partitioner.sk_num_blocks,
// k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
// big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
// avoid division by zero errors.
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
kargs.tile_partitioner.k_iters_per_tile.get());
return std::tuple{ave_time, num_wgs_per_tile};
};

View File

@@ -1,3 +1,4 @@
import os
import pathlib
from pathlib import Path
import subprocess
@@ -10,8 +11,12 @@ for p in sorted(Path("./").rglob("*")):
# formatting
for x in all_files:
subprocess.Popen(f"dos2unix -n {str(x)}", shell=True)
cmd = f"clang-format-18 -style=file -i {str(x)}"
subprocess.Popen(
f"python -m dos2unix {str(x)} {str(x)}",
shell=True,
stdout=open(os.devnull, "wb"),
)
cmd = f"clang-format -style=file -i {str(x)}"
# for xp in x.parents:
# print(get_file_base(x))
subprocess.Popen(cmd, shell=True)