mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Merge commit 'b03764ca5a917752845ddbb5da8886051a16d9be' into develop
This commit is contained in:
@@ -2,29 +2,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
|
||||
// Estimate the number of WGs contributing to the same macro tile in C
|
||||
template <ck_tile::StreamKReductionStrategy ReductionStrategy, typename TilePartitioner>
|
||||
int estimate_num_wgs_per_tile(const TilePartitioner& tile_partitioner)
|
||||
{
|
||||
// In the case of non-atomic reduction or DP only, there will always be 1 WG contributing to a
|
||||
// macro time in C
|
||||
int num_wgs_per_tile = 1;
|
||||
|
||||
// Otherwise, for atomics, multiple WGs may be contributing to the same macro tile in C
|
||||
if(tile_partitioner.sk_num_blocks > 0 &&
|
||||
ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Determine the number of iterations per WG for a given macro tile in C
|
||||
uint32_t k_iters_per_block = tile_partitioner.k_iters_per_big_block - 1;
|
||||
|
||||
// Estimate the number of WGs per macro tile
|
||||
num_wgs_per_tile = (tile_partitioner.k_iters_per_tile.get() / (k_iters_per_block)) +
|
||||
((tile_partitioner.k_iters_per_tile.get() % k_iters_per_block) != 0);
|
||||
}
|
||||
|
||||
return std::max(num_wgs_per_tile, 1);
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout)
|
||||
{
|
||||
@@ -65,7 +42,8 @@ template <typename GemmConfig,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s);
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s);
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
@@ -78,20 +56,21 @@ template <typename GemmConfig,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy,
|
||||
uint32_t num_sk_blocks)
|
||||
std::tuple<float, ck_tile::index_t>
|
||||
invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy,
|
||||
uint32_t num_sk_blocks)
|
||||
{
|
||||
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
@@ -105,7 +84,7 @@ std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
reduction_strategy,
|
||||
num_sk_blocks};
|
||||
|
||||
std::tuple<float, int> ave_time_and_batch;
|
||||
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
|
||||
|
||||
if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
@@ -16,7 +17,8 @@ template <typename GemmConfig,
|
||||
typename ELayout,
|
||||
typename CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s)
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -42,7 +44,7 @@ std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile:
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation) -> std::tuple<float, int> {
|
||||
const auto Run = [&](const auto memory_operation) -> std::tuple<float, ck_tile::index_t> {
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
@@ -113,7 +115,13 @@ std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile:
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
int num_wgs_per_tile = estimate_num_wgs_per_tile<ReductionStrategy>(kargs.tile_partitioner);
|
||||
ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
|
||||
kargs.tile_partitioner.sk_num_blocks,
|
||||
// k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
|
||||
// big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
|
||||
// avoid division by zero errors.
|
||||
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
|
||||
kargs.tile_partitioner.k_iters_per_tile.get());
|
||||
|
||||
return std::tuple{ave_time, num_wgs_per_tile};
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import pathlib
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
@@ -10,8 +11,12 @@ for p in sorted(Path("./").rglob("*")):
|
||||
|
||||
# formatting
|
||||
for x in all_files:
|
||||
subprocess.Popen(f"dos2unix -n {str(x)}", shell=True)
|
||||
cmd = f"clang-format-18 -style=file -i {str(x)}"
|
||||
subprocess.Popen(
|
||||
f"python -m dos2unix {str(x)} {str(x)}",
|
||||
shell=True,
|
||||
stdout=open(os.devnull, "wb"),
|
||||
)
|
||||
cmd = f"clang-format -style=file -i {str(x)}"
|
||||
# for xp in x.parents:
|
||||
# print(get_file_base(x))
|
||||
subprocess.Popen(cmd, shell=True)
|
||||
|
||||
Reference in New Issue
Block a user