[rocm-libraries] ROCm/rocm-libraries#4795 (commit 6590a1a)

[CK_TILE] Rename Stream-K grid function

## Motivation
This PR introduces a change in the name of the get_grid function in the
Stream-K TilePartitioner to avoid confusion with a similarly named
method. In the Stream-K TilePartitioner, there is get_grid() which
returns num_cu*occupancy and there is grid_size() which returns the grid
size used to launch the kernel. In this PR, we change get_grid() to be
get_max_active_wgs() to better reflect what the function returns and not
confuse it with grid_size().

## Technical Details
Initially in the Stream-K TilePartitioner we had get_grid() which
returned grid_. We are renaming get_grid() to get_max_active_wgs() and
grid_ to max_active_wgs_ internally, while keeping grid_size() the same.
The parameter, grid, for the Stream-K TilePartitioner remains the same
to maintain consistency with the rest of the Stream-K API.

## Test Plan
Validated using the test suite that is already present.

## Test Result
All tests passed

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
arai713
2026-03-20 09:28:47 +00:00
committed by assistant-librarian[bot]
parent a268a2a2e1
commit da863dae1b
5 changed files with 108 additions and 105 deletions

View File

@@ -119,7 +119,7 @@ struct StreamKKernel
struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
{
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t max_active_wgs)
: UniversalGemmKernelArgs{host_args.as_ptr,
host_args.bs_ptr,
host_args.ds_ptr,
@@ -135,7 +135,8 @@ struct StreamKKernel
// The workspace pointer is set to nullptr because we must first
// instantiate the TilePartitioner to get the necessary size
workspace_ptr{nullptr},
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
tile_partitioner{
TilePartitioner{host_args.M, host_args.N, host_args.K, max_active_wgs}}
{
}
@@ -206,9 +207,9 @@ struct StreamKKernel
int num_cu = NumCU(),
int occupancy = Occupancy())
{
const index_t grid = num_cu * occupancy;
const index_t max_active_wgs = num_cu * occupancy;
return StreamKKernelArgs{host_args, grid};
return StreamKKernelArgs{host_args, max_active_wgs};
}
template <bool UseDefaultScheduler = true>
@@ -790,7 +791,7 @@ struct StreamKKernel
// Data-parallel section
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
tile_idx += kargs.tile_partitioner.get_grid())
tile_idx += kargs.tile_partitioner.get_max_active_wgs())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
block_sync_lds();

View File

@@ -31,7 +31,7 @@ struct StreamKTilePartitionerBase
? memory_operation_enum::atomic_add
: memory_operation_enum::set;
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t max_number_wgs);
/**
* @brief Calculates the total space needed for the partials buffer.
@@ -156,7 +156,7 @@ struct StreamKTilePartitionerBase
* @brief Returns the maximum number of active workgroups; this is assumed to be number of CUs *
* occupancy.
*/
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept;
CK_TILE_HOST_DEVICE index_t get_max_active_wgs() const noexcept;
/**
* @brief Returns the number of tiles in the C tensor that will use the data-parallel (DP)
@@ -215,7 +215,7 @@ struct StreamKTilePartitionerBase
protected:
index_t num_tiles_;
index_t grid_;
index_t max_active_wgs_;
index_t dp_tiles_;
private:
@@ -270,7 +270,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>
StreamKTilePartitioner(ck_tile::index_t m,
ck_tile::index_t n,
ck_tile::index_t k,
ck_tile::index_t grid);
ck_tile::index_t max_active_wgs);
public:
static constexpr bool PERSISTENT = true;
@@ -290,7 +290,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>
/**
* @brief Returns the total number of DP tiles left over when `dp_tiles_` is not evenly
* divisible by `grid_`.
* divisible by `max_active_wgs_`.
*/
CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept;
@@ -317,7 +317,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>
StreamKTilePartitioner(ck_tile::index_t m,
ck_tile::index_t n,
ck_tile::index_t k,
ck_tile::index_t grid);
ck_tile::index_t max_number_wgs);
public:
static constexpr bool PERSISTENT = false;

View File

@@ -7,24 +7,25 @@ namespace ck_tile {
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTilePartitionerBase(
index_t m, index_t n, index_t k, index_t grid)
: grid_{grid}, n_{n}
index_t m, index_t n, index_t k, index_t max_active_wgs)
: max_active_wgs_{max_active_wgs}, n_{n}
{
iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock);
bool big_enough = num_tiles_ > grid_;
index_t remainder_tiles = num_tiles_ % grid_;
bool big_enough = num_tiles_ > max_active_wgs_;
index_t remainder_tiles = num_tiles_ % max_active_wgs_;
if(remainder_tiles)
{
sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
sk_tiles_ = min(num_tiles_, sk_tiles_);
sk_ctas_ = grid_;
sk_tiles_ = big_enough ? full_tiles_ * max_active_wgs_ + (num_tiles_ % max_active_wgs_)
: num_tiles_;
sk_tiles_ = min(num_tiles_, sk_tiles_);
sk_ctas_ = max_active_wgs_;
total_sk_iters_ = sk_tiles_ * iters_per_tile_;
// If there still isn't enough work to saturate all CUs, then just revert to DP only.
if(total_sk_iters_ < grid_)
if(total_sk_iters_ < max_active_wgs_)
{
sk_tiles_ = 0;
sk_ctas_ = 0;
@@ -175,9 +176,10 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_num_t
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_grid() const noexcept
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_max_active_wgs()
const noexcept
{
return grid_;
return max_active_wgs_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
@@ -287,11 +289,11 @@ struct StreamKTilePartitioner;
// child class for Persistent Tile Partitioner
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::StreamKTilePartitioner(
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, max_active_wgs)
{ // inherit from base constructor
dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
dp_tiles_per_cta_ = this->dp_tiles_ / this->max_active_wgs_;
extra_dp_tiles_ = this->dp_tiles_ % this->max_active_wgs_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
@@ -301,7 +303,7 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::grid_si
{
if(extra_dp_tiles_ == 0)
{
return dim3(this->grid_, 1, 1);
return dim3(this->max_active_wgs_, 1, 1);
}
else
{
@@ -328,8 +330,8 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::get_ext
// child class for Non-Persistent Tile Partitioner
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::StreamKTilePartitioner(
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs)
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, max_active_wgs)
{ // inherit from base constructor
dp_ctas_ = this->dp_tiles_;
dp_start_block_idx_ = 0;

View File

@@ -8,10 +8,10 @@ TEST(StreamKTilePartitionerBaseConstructor, SKOnly)
using Config = StreamKTilePartitionerBaseConfigSKOnly;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerBaseExpected expected_values{
2, 0, 3, 4, 1, 2, 1, 0, 2, Config::GRID, Config::N};
2, 0, 3, 4, 1, 2, 1, 0, 2, Config::MAX_ACTIVE_WGS, Config::N};
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
}
@@ -20,10 +20,10 @@ TEST(StreamKTilePartitionerBaseConstructor, DPOnly)
using Config = StreamKTilePartitionerBaseConfigDPOnly;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerBaseExpected expected_values{
0, 6, 0, 0, 0, 2, 0, 12, 6, Config::GRID, Config::N};
0, 6, 0, 0, 0, 2, 0, 12, 6, Config::MAX_ACTIVE_WGS, Config::N};
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
}
@@ -32,10 +32,10 @@ TEST(StreamKTilePartitionerBaseConstructor, DP2TileSK)
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerBaseExpected expected_values{
4, 3, 3, 8, 2, 2, 2, 6, 7, Config::GRID, Config::N};
4, 3, 3, 8, 2, 2, 2, 6, 7, Config::MAX_ACTIVE_WGS, Config::N};
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
}
@@ -44,10 +44,10 @@ TEST(StreamKTilePartitionerBaseConstructor, EdgeCase)
using Config = StreamKTilePartitionerBaseConfigEdgeCase;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerBaseExpected expected_values{
0, 1, 0, 0, 0, 2, 0, 2, 1, Config::GRID, Config::N};
0, 1, 0, 0, 0, 2, 0, 2, 1, Config::MAX_ACTIVE_WGS, Config::N};
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
}
@@ -57,7 +57,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsLessThan128Bytes)
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
ck_tile::StreamKReductionStrategy::Linear>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
}
@@ -68,7 +68,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsEqual128Bytes)
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
ck_tile::StreamKReductionStrategy::Linear>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
}
@@ -79,7 +79,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsGreaterThan128Bytes)
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
ck_tile::StreamKReductionStrategy::Linear>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 256);
}
@@ -89,7 +89,7 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy)
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), 0);
}
@@ -100,12 +100,12 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy)
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
ck_tile::StreamKReductionStrategy::Linear>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
ck_tile::index_t expected_partials_size =
sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID;
// Since GRID is 3, the final padded flags array must be 128B to ensure the total byte size of
// the flags array is 128B-aligned.
sizeof(float) * Config::M_TILE * Config::N_TILE * Config::MAX_ACTIVE_WGS;
// Since MAX_ACTIVE_WGS is 3, the final padded flags array must be 128B to ensure the total byte
// size of the flags array is 128B-aligned.
ck_tile::index_t expected_flags_size = 128;
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)),
@@ -117,7 +117,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileLower
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2);
}
@@ -127,7 +127,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileEqual
using Config = StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2);
}
@@ -232,7 +232,7 @@ TEST(StreamKTilePartitionerBaseGetTileBoundaries, GetTileBoundaries)
// Test parameters
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
ck_tile::DeviceMem tile_iter_start_dev(sizeof(ck_tile::index_t));
ck_tile::DeviceMem tile_iter_end_dev(sizeof(ck_tile::index_t));
ck_tile::index_t tile_idx = 1;
@@ -267,7 +267,7 @@ TEST(StreamKTilePartitionerBaseGetTileIndex, GetTileIndex)
// Test parameters
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
ck_tile::DeviceMem tile_idx_dev(sizeof(ck_tile::index_t));
ck_tile::index_t iter_start = 8;
@@ -299,7 +299,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, ZeroExtraItersBeforeMe)
// Test parameters
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
ck_tile::index_t cta_idx = 0;
@@ -333,7 +333,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, NonZeroExtraItersBeforeMe)
// Test parameters
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
ck_tile::index_t cta_idx = 1;
@@ -367,7 +367,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, MinIsExtraIters)
// Test parameters
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
ck_tile::index_t cta_idx = 2;
@@ -493,7 +493,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly)
ck_tile::
StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerV2PersistentExpected expected_values{0, 0, 3};
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -506,7 +506,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, DPOnly)
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
ck_tile::StreamKReductionStrategy::Atomic,
true>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerV2PersistentExpected expected_values{2, 0, 3};
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -519,7 +519,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, DP2TileSK)
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
ck_tile::StreamKReductionStrategy::Atomic,
true>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerV2PersistentExpected expected_values{1, 0, 3};
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -532,7 +532,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, EdgeCase)
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
ck_tile::StreamKReductionStrategy::Atomic,
true>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerV2PersistentExpected expected_values{0, 1, 4};
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -545,10 +545,10 @@ TEST(StreamKTilePartitioner_GridSize_Persistent, SKOnly)
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
ck_tile::StreamKReductionStrategy::Atomic,
true>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
const auto g = tile_partitioner.grid_size();
EXPECT_EQ(g.x, Config::GRID);
EXPECT_EQ(g.x, Config::MAX_ACTIVE_WGS);
}
TEST(StreamKTilePartitioner_GridSize_Persistent, EdgeCase)
@@ -558,7 +558,7 @@ TEST(StreamKTilePartitioner_GridSize_Persistent, EdgeCase)
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
ck_tile::StreamKReductionStrategy::Atomic,
true>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
const auto g = tile_partitioner.grid_size();
EXPECT_EQ(g.x, 1);
@@ -571,7 +571,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, SKOnly)
ck_tile::
StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, false>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerV2NonPersistentExpected expected_values{0, 0, 0, 3};
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -584,7 +584,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, DPOnly)
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
ck_tile::StreamKReductionStrategy::Atomic,
false>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerV2NonPersistentExpected expected_values{6, 0, 6, 3};
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -597,7 +597,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, DP2TileSK)
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
ck_tile::StreamKReductionStrategy::Atomic,
false>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerV2NonPersistentExpected expected_values{3, 0, 3, 3};
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -610,7 +610,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, EdgeCase)
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
ck_tile::StreamKReductionStrategy::Atomic,
false>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
StreamKTilePartitionerV2NonPersistentExpected expected_values{1, 0, 1, 4};
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
@@ -623,7 +623,7 @@ TEST(StreamKTilePartitioner_GridSize_NonPersistent, DP2TileSK)
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
ck_tile::StreamKReductionStrategy::Atomic,
false>
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
const auto g = tile_partitioner.grid_size();
EXPECT_EQ(g.x, 6);

View File

@@ -165,7 +165,7 @@ struct StreamKTilePartitionerBaseExpected
ck_tile::index_t extra_iters_;
ck_tile::index_t total_dp_iters_;
ck_tile::index_t num_tiles_;
ck_tile::index_t grid_;
ck_tile::index_t max_active_wgs_;
ck_tile::index_t n_;
};
@@ -183,7 +183,7 @@ void validate_streamk_base_constructor(
EXPECT_EQ(tile_partitioner.get_iters_per_tile(), expected_values.iters_per_tile_);
EXPECT_EQ(tile_partitioner.get_total_dp_iters(), expected_values.total_dp_iters_);
EXPECT_EQ(tile_partitioner.get_num_tiles(), expected_values.num_tiles_);
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.max_active_wgs_);
EXPECT_EQ(tile_partitioner.get_n(), expected_values.n_);
}
@@ -201,9 +201,9 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner
static constexpr ck_tile::index_t M = 28;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
// The minimum number of bytes needed for the flags array is GRID * 4B = 3 * 4B = 12B. To ensure
// the total byte size of the array is 128B-aligned, the flags array must be 128B.
static constexpr ck_tile::index_t GRID = 3;
// The minimum number of bytes needed for the flags array is MAX_ACTIVE_WGS * 4B = 3 * 4B = 12B.
// To ensure the total byte size of the array is 128B-aligned, the flags array must be 128B.
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 3;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 4;
@@ -220,9 +220,9 @@ struct StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes
static constexpr ck_tile::index_t M = 28;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 32;
// The minimum number of bytes needed for the flags array is GRID * 4B = 32 * 4B = 128B. So, the
// number of bytes for the flags array should be 128B.
static constexpr ck_tile::index_t GRID = 32;
// The minimum number of bytes needed for the flags array is MAX_ACTIVE_WGS * 4B = 32 * 4B =
// 128B. So, the number of bytes for the flags array should be 128B.
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 32;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 4;
@@ -239,10 +239,10 @@ struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes
static constexpr ck_tile::index_t M = 28;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 33;
// The minimum number of bytes needed for the flags array is GRID * 4B = 33 * 4B = 132B. So, the
// number of bytes for the flags array should be 2 * 128B = 256B to ensure the total byte size
// of the array is 128B-aligned.
static constexpr ck_tile::index_t GRID = 33;
// The minimum number of bytes needed for the flags array is MAX_ACTIVE_WGS * 4B = 33 * 4B =
// 132B. So, the number of bytes for the flags array should be 2 * 128B = 256B to ensure the
// total byte size of the array is 128B-aligned.
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 33;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 4;
@@ -256,10 +256,10 @@ struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes
struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile
: public StreamKTilePartitionerBaseConfig
{
static constexpr ck_tile::index_t M = 16;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t GRID = 8;
static constexpr ck_tile::index_t M = 16;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 8;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 4;
@@ -272,10 +272,10 @@ struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile
struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBaseConfig
{
static constexpr ck_tile::index_t M = 12;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t GRID = 3;
static constexpr ck_tile::index_t M = 12;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 3;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 2;
@@ -288,10 +288,10 @@ struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBas
struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBaseConfig
{
static constexpr ck_tile::index_t M = 4;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t GRID = 3;
static constexpr ck_tile::index_t M = 4;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 3;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 2;
@@ -304,10 +304,10 @@ struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBas
struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitionerBaseConfig
{
static constexpr ck_tile::index_t M = 8;
static constexpr ck_tile::index_t N = 2;
static constexpr ck_tile::index_t K = 10;
static constexpr ck_tile::index_t GRID = 5;
static constexpr ck_tile::index_t M = 8;
static constexpr ck_tile::index_t N = 2;
static constexpr ck_tile::index_t K = 10;
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 5;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 2;
@@ -321,10 +321,10 @@ struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitio
struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig
{
static constexpr ck_tile::index_t M = 4;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t GRID = 4;
static constexpr ck_tile::index_t M = 4;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 4;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 4;
@@ -340,10 +340,10 @@ struct StreamKTilePartitionerBaseConfigLargerCTensor : public StreamKTilePartiti
// This config has 3 macro tiles in the M dimension and 4 macro tiles in the N dimension.
// This facilitates testing the get_output_tile_index method.
static constexpr ck_tile::index_t M = 12;
static constexpr ck_tile::index_t N = 16;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t GRID = 4;
static constexpr ck_tile::index_t M = 12;
static constexpr ck_tile::index_t N = 16;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 4;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 4;
@@ -366,7 +366,7 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx,
// Test parameters
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
ck_tile::DeviceMem im_dev(sizeof(ck_tile::index_t));
ck_tile::DeviceMem in_dev(sizeof(ck_tile::index_t));
@@ -402,7 +402,7 @@ void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start,
// Test parameters
ck_tile::StreamKTilePartitionerBase<typename Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
ck_tile::DeviceMem tile_local_cta_idx_dev(sizeof(ck_tile::index_t));
// Launch kernel
@@ -426,7 +426,7 @@ struct StreamKTilePartitionerV2PersistentExpected
{
ck_tile::index_t dp_tiles_per_cta_;
ck_tile::index_t extra_dp_tiles_;
ck_tile::index_t grid_;
ck_tile::index_t max_active_wgs_;
};
struct StreamKTilePartitionerV2NonPersistentExpected
@@ -434,7 +434,7 @@ struct StreamKTilePartitionerV2NonPersistentExpected
ck_tile::index_t dp_ctas_;
ck_tile::index_t dp_start_block_idx_;
ck_tile::index_t sk_start_block_idx_;
ck_tile::index_t grid_;
ck_tile::index_t max_active_wgs_;
};
// Persistent
@@ -446,7 +446,7 @@ void validate_streamk_persistent(
{
EXPECT_EQ(tile_partitioner.get_dp_tiles_per_cta(), expected_values.dp_tiles_per_cta_);
EXPECT_EQ(tile_partitioner.get_extra_dp_tiles(), expected_values.extra_dp_tiles_);
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.max_active_wgs_);
}
// Non-Persistent
@@ -459,5 +459,5 @@ void validate_streamk_nonpersistent(
EXPECT_EQ(tile_partitioner.get_dp_ctas(), expected_values.dp_ctas_);
EXPECT_EQ(tile_partitioner.get_dp_start_block_idx(), expected_values.dp_start_block_idx_);
EXPECT_EQ(tile_partitioner.get_sk_start_block_idx(), expected_values.sk_start_block_idx_);
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.max_active_wgs_);
}