[rocm-libraries] ROCm/rocm-libraries#4795 (commit 6590a1a)

[CK_TILE] Rename Stream-K grid function

## Motivation
This PR introduces a change in the name of the get_grid function in the
Stream-K TilePartitioner to avoid confusion with a similarly named
method. In the Stream-K TilePartitioner, there is get_grid() which
returns num_cu*occupancy and there is grid_size() which returns the grid
size used to launch the kernel. In this PR, we change get_grid() to be
get_max_active_wgs() to better reflect what the function returns and not
confuse it with grid_size().

## Technical Details
Initially in the Stream-K TilePartitioner we had get_grid() which
returned grid_. We are renaming get_grid() to get_max_active_wgs() and
grid_ to max_active_wgs_ internally, while keeping grid_size() the same.
The parameter, grid, for the Stream-K TilePartitioner remains the same
to maintain consistency with the rest of the Stream-K API.

## Test Plan
Validated using the test suite that is already present.

## Test Result
All tests passed

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
arai713
2026-03-20 09:28:47 +00:00
committed by assistant-librarian[bot]
parent a268a2a2e1
commit da863dae1b
5 changed files with 108 additions and 105 deletions

View File

@@ -119,7 +119,7 @@ struct StreamKKernel
struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
{
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t max_active_wgs)
: UniversalGemmKernelArgs{host_args.as_ptr,
host_args.bs_ptr,
host_args.ds_ptr,
@@ -135,7 +135,8 @@ struct StreamKKernel
// The workspace pointer is set to nullptr because we must first
// instantiate the TilePartitioner to get the necessary size
workspace_ptr{nullptr},
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
tile_partitioner{
TilePartitioner{host_args.M, host_args.N, host_args.K, max_active_wgs}}
{
}
@@ -206,9 +207,9 @@ struct StreamKKernel
int num_cu = NumCU(),
int occupancy = Occupancy())
{
const index_t grid = num_cu * occupancy;
const index_t max_active_wgs = num_cu * occupancy;
return StreamKKernelArgs{host_args, grid};
return StreamKKernelArgs{host_args, max_active_wgs};
}
template <bool UseDefaultScheduler = true>
@@ -790,7 +791,7 @@ struct StreamKKernel
// Data-parallel section
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
tile_idx += kargs.tile_partitioner.get_grid())
tile_idx += kargs.tile_partitioner.get_max_active_wgs())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
block_sync_lds();

View File

@@ -31,7 +31,7 @@ struct StreamKTilePartitionerBase
? memory_operation_enum::atomic_add
: memory_operation_enum::set;
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t max_number_wgs);
/**
* @brief Calculates the total space needed for the partials buffer.
@@ -156,7 +156,7 @@ struct StreamKTilePartitionerBase
* @brief Returns the maximum number of active workgroups; this is assumed to be number of CUs *
* occupancy.
*/
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept;
CK_TILE_HOST_DEVICE index_t get_max_active_wgs() const noexcept;
/**
* @brief Returns the number of tiles in the C tensor that will use the data-parallel (DP)
@@ -215,7 +215,7 @@ struct StreamKTilePartitionerBase
protected:
index_t num_tiles_;
index_t grid_;
index_t max_active_wgs_;
index_t dp_tiles_;
private:
@@ -270,7 +270,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>
StreamKTilePartitioner(ck_tile::index_t m,
ck_tile::index_t n,
ck_tile::index_t k,
ck_tile::index_t grid);
ck_tile::index_t max_active_wgs);
public:
static constexpr bool PERSISTENT = true;
@@ -290,7 +290,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>
/**
* @brief Returns the total number of DP tiles left over when `dp_tiles_` is not evenly
* divisible by `grid_`.
* divisible by `max_active_wgs_`.
*/
CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept;
@@ -317,7 +317,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>
StreamKTilePartitioner(ck_tile::index_t m,
ck_tile::index_t n,
ck_tile::index_t k,
ck_tile::index_t grid);
ck_tile::index_t max_number_wgs);
public:
static constexpr bool PERSISTENT = false;

View File

@@ -7,24 +7,25 @@ namespace ck_tile {
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTilePartitionerBase(
index_t m, index_t n, index_t k, index_t grid)
: grid_{grid}, n_{n}
index_t m, index_t n, index_t k, index_t max_active_wgs)
: max_active_wgs_{max_active_wgs}, n_{n}
{
iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock);
bool big_enough = num_tiles_ > grid_;
index_t remainder_tiles = num_tiles_ % grid_;
bool big_enough = num_tiles_ > max_active_wgs_;
index_t remainder_tiles = num_tiles_ % max_active_wgs_;
if(remainder_tiles)
{
sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
sk_tiles_ = min(num_tiles_, sk_tiles_);
sk_ctas_ = grid_;
sk_tiles_ = big_enough ? full_tiles_ * max_active_wgs_ + (num_tiles_ % max_active_wgs_)
: num_tiles_;
sk_tiles_ = min(num_tiles_, sk_tiles_);
sk_ctas_ = max_active_wgs_;
total_sk_iters_ = sk_tiles_ * iters_per_tile_;
// If there still isn't enough work to saturate all CUs, then just revert to DP only.
if(total_sk_iters_ < grid_)
if(total_sk_iters_ < max_active_wgs_)
{
sk_tiles_ = 0;
sk_ctas_ = 0;
@@ -175,9 +176,10 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_num_t
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_grid() const noexcept
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_max_active_wgs()
const noexcept
{
return grid_;
return max_active_wgs_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
@@ -287,11 +289,11 @@ struct StreamKTilePartitioner;
// child class for Persistent Tile Partitioner
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::StreamKTilePartitioner(
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, max_active_wgs)
{ // inherit from base constructor
dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
dp_tiles_per_cta_ = this->dp_tiles_ / this->max_active_wgs_;
extra_dp_tiles_ = this->dp_tiles_ % this->max_active_wgs_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
@@ -301,7 +303,7 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::grid_si
{
if(extra_dp_tiles_ == 0)
{
return dim3(this->grid_, 1, 1);
return dim3(this->max_active_wgs_, 1, 1);
}
else
{
@@ -328,8 +330,8 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::get_ext
// child class for Non-Persistent Tile Partitioner
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::StreamKTilePartitioner(
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, max_active_wgs)
{ // inherit from base constructor
dp_ctas_ = this->dp_tiles_;
dp_start_block_idx_ = 0;