mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#4795 (commit 6590a1a)
[CK_TILE] Rename Stream-K grid function ## Motivation This PR introduces a change in the name of the get_grid function in the Stream-K TilePartitioner to avoid confusion with a similarly named method. In the Stream-K TilePartitioner, there is get_grid() which returns num_cu*occupancy and there is grid_size() which returns the grid size used to launch the kernel. In this PR, we change get_grid() to be get_max_active_wgs() to better reflect what the function returns and not confuse it with grid_size(). ## Technical Details Initially in the Stream-K TilePartitioner we had get_grid() which returned grid_. We are renaming get_grid() to get_max_active_wgs() and grid_ to max_active_wgs_ internally, while keeping grid_size() the same. The parameter, grid, for the Stream-K TilePartitioner remains the same to maintain consistency with the rest of the Stream-K API. ## Test Plan Validated using the test suite that is already present. ## Test Result All tests passed ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
a268a2a2e1
commit
da863dae1b
@@ -119,7 +119,7 @@ struct StreamKKernel
|
||||
|
||||
struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
|
||||
{
|
||||
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
|
||||
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t max_active_wgs)
|
||||
: UniversalGemmKernelArgs{host_args.as_ptr,
|
||||
host_args.bs_ptr,
|
||||
host_args.ds_ptr,
|
||||
@@ -135,7 +135,8 @@ struct StreamKKernel
|
||||
// The workspace pointer is set to nullptr because we must first
|
||||
// instantiate the TilePartitioner to get the necessary size
|
||||
workspace_ptr{nullptr},
|
||||
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
|
||||
tile_partitioner{
|
||||
TilePartitioner{host_args.M, host_args.N, host_args.K, max_active_wgs}}
|
||||
|
||||
{
|
||||
}
|
||||
@@ -206,9 +207,9 @@ struct StreamKKernel
|
||||
int num_cu = NumCU(),
|
||||
int occupancy = Occupancy())
|
||||
{
|
||||
const index_t grid = num_cu * occupancy;
|
||||
const index_t max_active_wgs = num_cu * occupancy;
|
||||
|
||||
return StreamKKernelArgs{host_args, grid};
|
||||
return StreamKKernelArgs{host_args, max_active_wgs};
|
||||
}
|
||||
|
||||
template <bool UseDefaultScheduler = true>
|
||||
@@ -790,7 +791,7 @@ struct StreamKKernel
|
||||
|
||||
// Data-parallel section
|
||||
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
|
||||
tile_idx += kargs.tile_partitioner.get_grid())
|
||||
tile_idx += kargs.tile_partitioner.get_max_active_wgs())
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
|
||||
block_sync_lds();
|
||||
|
||||
@@ -31,7 +31,7 @@ struct StreamKTilePartitionerBase
|
||||
? memory_operation_enum::atomic_add
|
||||
: memory_operation_enum::set;
|
||||
|
||||
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
|
||||
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t max_number_wgs);
|
||||
|
||||
/**
|
||||
* @brief Calculates the total space needed for the partials buffer.
|
||||
@@ -156,7 +156,7 @@ struct StreamKTilePartitionerBase
|
||||
* @brief Returns the maximum number of active workgroups; this is assumed to be number of CUs *
|
||||
* occupancy.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept;
|
||||
CK_TILE_HOST_DEVICE index_t get_max_active_wgs() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the number of tiles in the C tensor that will use the data-parallel (DP)
|
||||
@@ -215,7 +215,7 @@ struct StreamKTilePartitionerBase
|
||||
|
||||
protected:
|
||||
index_t num_tiles_;
|
||||
index_t grid_;
|
||||
index_t max_active_wgs_;
|
||||
index_t dp_tiles_;
|
||||
|
||||
private:
|
||||
@@ -270,7 +270,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>
|
||||
StreamKTilePartitioner(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid);
|
||||
ck_tile::index_t max_active_wgs);
|
||||
|
||||
public:
|
||||
static constexpr bool PERSISTENT = true;
|
||||
@@ -290,7 +290,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of DP tiles left over when `dp_tiles_` is not evenly
|
||||
* divisible by `grid_`.
|
||||
* divisible by `max_active_wgs_`.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept;
|
||||
|
||||
@@ -317,7 +317,7 @@ struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>
|
||||
StreamKTilePartitioner(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid);
|
||||
ck_tile::index_t max_number_wgs);
|
||||
|
||||
public:
|
||||
static constexpr bool PERSISTENT = false;
|
||||
|
||||
@@ -7,24 +7,25 @@ namespace ck_tile {
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTilePartitionerBase(
|
||||
index_t m, index_t n, index_t k, index_t grid)
|
||||
: grid_{grid}, n_{n}
|
||||
index_t m, index_t n, index_t k, index_t max_active_wgs)
|
||||
: max_active_wgs_{max_active_wgs}, n_{n}
|
||||
{
|
||||
iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
|
||||
num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock);
|
||||
|
||||
bool big_enough = num_tiles_ > grid_;
|
||||
index_t remainder_tiles = num_tiles_ % grid_;
|
||||
bool big_enough = num_tiles_ > max_active_wgs_;
|
||||
index_t remainder_tiles = num_tiles_ % max_active_wgs_;
|
||||
|
||||
if(remainder_tiles)
|
||||
{
|
||||
sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
|
||||
sk_tiles_ = min(num_tiles_, sk_tiles_);
|
||||
sk_ctas_ = grid_;
|
||||
sk_tiles_ = big_enough ? full_tiles_ * max_active_wgs_ + (num_tiles_ % max_active_wgs_)
|
||||
: num_tiles_;
|
||||
sk_tiles_ = min(num_tiles_, sk_tiles_);
|
||||
sk_ctas_ = max_active_wgs_;
|
||||
total_sk_iters_ = sk_tiles_ * iters_per_tile_;
|
||||
|
||||
// If there still isn't enough work to saturate all CUs, then just revert to DP only.
|
||||
if(total_sk_iters_ < grid_)
|
||||
if(total_sk_iters_ < max_active_wgs_)
|
||||
{
|
||||
sk_tiles_ = 0;
|
||||
sk_ctas_ = 0;
|
||||
@@ -175,9 +176,10 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_num_t
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_grid() const noexcept
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_max_active_wgs()
|
||||
const noexcept
|
||||
{
|
||||
return grid_;
|
||||
return max_active_wgs_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
@@ -287,11 +289,11 @@ struct StreamKTilePartitioner;
|
||||
// child class for Persistent Tile Partitioner
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::StreamKTilePartitioner(
|
||||
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
|
||||
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, max_active_wgs)
|
||||
{ // inherit from base constructor
|
||||
dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
|
||||
extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
|
||||
dp_tiles_per_cta_ = this->dp_tiles_ / this->max_active_wgs_;
|
||||
extra_dp_tiles_ = this->dp_tiles_ % this->max_active_wgs_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
@@ -301,7 +303,7 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::grid_si
|
||||
{
|
||||
if(extra_dp_tiles_ == 0)
|
||||
{
|
||||
return dim3(this->grid_, 1, 1);
|
||||
return dim3(this->max_active_wgs_, 1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -328,8 +330,8 @@ StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::get_ext
|
||||
// child class for Non-Persistent Tile Partitioner
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::StreamKTilePartitioner(
|
||||
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
|
||||
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t max_active_wgs)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, max_active_wgs)
|
||||
{ // inherit from base constructor
|
||||
dp_ctas_ = this->dp_tiles_;
|
||||
dp_start_block_idx_ = 0;
|
||||
|
||||
@@ -8,10 +8,10 @@ TEST(StreamKTilePartitionerBaseConstructor, SKOnly)
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
2, 0, 3, 4, 1, 2, 1, 0, 2, Config::GRID, Config::N};
|
||||
2, 0, 3, 4, 1, 2, 1, 0, 2, Config::MAX_ACTIVE_WGS, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
@@ -20,10 +20,10 @@ TEST(StreamKTilePartitionerBaseConstructor, DPOnly)
|
||||
using Config = StreamKTilePartitionerBaseConfigDPOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
0, 6, 0, 0, 0, 2, 0, 12, 6, Config::GRID, Config::N};
|
||||
0, 6, 0, 0, 0, 2, 0, 12, 6, Config::MAX_ACTIVE_WGS, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
@@ -32,10 +32,10 @@ TEST(StreamKTilePartitionerBaseConstructor, DP2TileSK)
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
4, 3, 3, 8, 2, 2, 2, 6, 7, Config::GRID, Config::N};
|
||||
4, 3, 3, 8, 2, 2, 2, 6, 7, Config::MAX_ACTIVE_WGS, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
@@ -44,10 +44,10 @@ TEST(StreamKTilePartitionerBaseConstructor, EdgeCase)
|
||||
using Config = StreamKTilePartitionerBaseConfigEdgeCase;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
0, 1, 0, 0, 0, 2, 0, 2, 1, Config::GRID, Config::N};
|
||||
0, 1, 0, 0, 0, 2, 0, 2, 1, Config::MAX_ACTIVE_WGS, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
@@ -57,7 +57,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsLessThan128Bytes)
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Linear>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
|
||||
}
|
||||
@@ -68,7 +68,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsEqual128Bytes)
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Linear>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 128);
|
||||
}
|
||||
@@ -79,7 +79,7 @@ TEST(StreamKTilePartitionerBaseGetFlagsBufferSize, FlagsGreaterThan128Bytes)
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Linear>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_flags_buffer_size(), 256);
|
||||
}
|
||||
@@ -89,7 +89,7 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy)
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), 0);
|
||||
}
|
||||
@@ -100,12 +100,12 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy)
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Linear>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
ck_tile::index_t expected_partials_size =
|
||||
sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID;
|
||||
// Since GRID is 3, the final padded flags array must be 128B to ensure the total byte size of
|
||||
// the flags array is 128B-aligned.
|
||||
sizeof(float) * Config::M_TILE * Config::N_TILE * Config::MAX_ACTIVE_WGS;
|
||||
// Since MAX_ACTIVE_WGS is 3, the final padded flags array must be 128B to ensure the total byte
|
||||
// size of the flags array is 128B-aligned.
|
||||
ck_tile::index_t expected_flags_size = 128;
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)),
|
||||
@@ -117,7 +117,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileLower
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2);
|
||||
}
|
||||
@@ -127,7 +127,7 @@ TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileEqual
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2);
|
||||
}
|
||||
@@ -232,7 +232,7 @@ TEST(StreamKTilePartitionerBaseGetTileBoundaries, GetTileBoundaries)
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
ck_tile::DeviceMem tile_iter_start_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem tile_iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t tile_idx = 1;
|
||||
@@ -267,7 +267,7 @@ TEST(StreamKTilePartitionerBaseGetTileIndex, GetTileIndex)
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
ck_tile::DeviceMem tile_idx_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t iter_start = 8;
|
||||
|
||||
@@ -299,7 +299,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, ZeroExtraItersBeforeMe)
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t cta_idx = 0;
|
||||
@@ -333,7 +333,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, NonZeroExtraItersBeforeMe)
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t cta_idx = 1;
|
||||
@@ -367,7 +367,7 @@ TEST(StreamKTilePartitionerBaseGetIterBoundaries, MinIsExtraIters)
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t cta_idx = 2;
|
||||
@@ -493,7 +493,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly)
|
||||
|
||||
ck_tile::
|
||||
StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{0, 0, 3};
|
||||
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
@@ -506,7 +506,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, DPOnly)
|
||||
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{2, 0, 3};
|
||||
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
@@ -519,7 +519,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, DP2TileSK)
|
||||
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{1, 0, 3};
|
||||
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
@@ -532,7 +532,7 @@ TEST(StreamKTilePartitioner_PersistentConstructor, EdgeCase)
|
||||
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{0, 1, 4};
|
||||
validate_streamk_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
@@ -545,10 +545,10 @@ TEST(StreamKTilePartitioner_GridSize_Persistent, SKOnly)
|
||||
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
const auto g = tile_partitioner.grid_size();
|
||||
EXPECT_EQ(g.x, Config::GRID);
|
||||
EXPECT_EQ(g.x, Config::MAX_ACTIVE_WGS);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_GridSize_Persistent, EdgeCase)
|
||||
@@ -558,7 +558,7 @@ TEST(StreamKTilePartitioner_GridSize_Persistent, EdgeCase)
|
||||
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
const auto g = tile_partitioner.grid_size();
|
||||
EXPECT_EQ(g.x, 1);
|
||||
@@ -571,7 +571,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, SKOnly)
|
||||
|
||||
ck_tile::
|
||||
StreamKTilePartitioner<Config::GemmShape, ck_tile::StreamKReductionStrategy::Atomic, false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{0, 0, 0, 3};
|
||||
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
@@ -584,7 +584,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, DPOnly)
|
||||
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{6, 0, 6, 3};
|
||||
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
@@ -597,7 +597,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, DP2TileSK)
|
||||
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{3, 0, 3, 3};
|
||||
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
@@ -610,7 +610,7 @@ TEST(StreamKTilePartitioner_NonPersistentConstructor, EdgeCase)
|
||||
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{1, 0, 1, 4};
|
||||
validate_streamk_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
@@ -623,7 +623,7 @@ TEST(StreamKTilePartitioner_GridSize_NonPersistent, DP2TileSK)
|
||||
ck_tile::StreamKTilePartitioner<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
|
||||
const auto g = tile_partitioner.grid_size();
|
||||
EXPECT_EQ(g.x, 6);
|
||||
|
||||
@@ -165,7 +165,7 @@ struct StreamKTilePartitionerBaseExpected
|
||||
ck_tile::index_t extra_iters_;
|
||||
ck_tile::index_t total_dp_iters_;
|
||||
ck_tile::index_t num_tiles_;
|
||||
ck_tile::index_t grid_;
|
||||
ck_tile::index_t max_active_wgs_;
|
||||
ck_tile::index_t n_;
|
||||
};
|
||||
|
||||
@@ -183,7 +183,7 @@ void validate_streamk_base_constructor(
|
||||
EXPECT_EQ(tile_partitioner.get_iters_per_tile(), expected_values.iters_per_tile_);
|
||||
EXPECT_EQ(tile_partitioner.get_total_dp_iters(), expected_values.total_dp_iters_);
|
||||
EXPECT_EQ(tile_partitioner.get_num_tiles(), expected_values.num_tiles_);
|
||||
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
|
||||
EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.max_active_wgs_);
|
||||
EXPECT_EQ(tile_partitioner.get_n(), expected_values.n_);
|
||||
}
|
||||
|
||||
@@ -201,9 +201,9 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner
|
||||
static constexpr ck_tile::index_t M = 28;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
// The minimum number of bytes needed for the flags array is GRID * 4B = 3 * 4B = 12B. To ensure
|
||||
// the total byte size of the array is 128B-aligned, the flags array must be 128B.
|
||||
static constexpr ck_tile::index_t GRID = 3;
|
||||
// The minimum number of bytes needed for the flags array is MAX_ACTIVE_WGS * 4B = 3 * 4B = 12B.
|
||||
// To ensure the total byte size of the array is 128B-aligned, the flags array must be 128B.
|
||||
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 3;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
@@ -220,9 +220,9 @@ struct StreamKTilePartitionerBaseConfigFlagsSizeEqual128Bytes
|
||||
static constexpr ck_tile::index_t M = 28;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 32;
|
||||
// The minimum number of bytes needed for the flags array is GRID * 4B = 32 * 4B = 128B. So, the
|
||||
// number of bytes for the flags array should be 128B.
|
||||
static constexpr ck_tile::index_t GRID = 32;
|
||||
// The minimum number of bytes needed for the flags array is MAX_ACTIVE_WGS * 4B = 32 * 4B =
|
||||
// 128B. So, the number of bytes for the flags array should be 128B.
|
||||
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
@@ -239,10 +239,10 @@ struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes
|
||||
static constexpr ck_tile::index_t M = 28;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 33;
|
||||
// The minimum number of bytes needed for the flags array is GRID * 4B = 33 * 4B = 132B. So, the
|
||||
// number of bytes for the flags array should be 2 * 128B = 256B to ensure the total byte size
|
||||
// of the array is 128B-aligned.
|
||||
static constexpr ck_tile::index_t GRID = 33;
|
||||
// The minimum number of bytes needed for the flags array is MAX_ACTIVE_WGS * 4B = 33 * 4B =
|
||||
// 132B. So, the number of bytes for the flags array should be 2 * 128B = 256B to ensure the
|
||||
// total byte size of the array is 128B-aligned.
|
||||
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 33;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
@@ -256,10 +256,10 @@ struct StreamKTilePartitionerBaseConfigFlagsSizeGreaterThan128Bytes
|
||||
struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile
|
||||
: public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 16;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 8;
|
||||
static constexpr ck_tile::index_t M = 16;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 8;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
@@ -272,10 +272,10 @@ struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 12;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 3;
|
||||
static constexpr ck_tile::index_t M = 12;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 3;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 2;
|
||||
@@ -288,10 +288,10 @@ struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBas
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 4;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 3;
|
||||
static constexpr ck_tile::index_t M = 4;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 3;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 2;
|
||||
@@ -304,10 +304,10 @@ struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBas
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 8;
|
||||
static constexpr ck_tile::index_t N = 2;
|
||||
static constexpr ck_tile::index_t K = 10;
|
||||
static constexpr ck_tile::index_t GRID = 5;
|
||||
static constexpr ck_tile::index_t M = 8;
|
||||
static constexpr ck_tile::index_t N = 2;
|
||||
static constexpr ck_tile::index_t K = 10;
|
||||
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 5;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 2;
|
||||
@@ -321,10 +321,10 @@ struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitio
|
||||
struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
|
||||
static constexpr ck_tile::index_t M = 4;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 4;
|
||||
static constexpr ck_tile::index_t M = 4;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 4;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
@@ -340,10 +340,10 @@ struct StreamKTilePartitionerBaseConfigLargerCTensor : public StreamKTilePartiti
|
||||
// This config has 3 macro tiles in the M dimension and 4 macro tiles in the N dimension.
|
||||
// This facilitates testing the get_output_tile_index method.
|
||||
|
||||
static constexpr ck_tile::index_t M = 12;
|
||||
static constexpr ck_tile::index_t N = 16;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 4;
|
||||
static constexpr ck_tile::index_t M = 12;
|
||||
static constexpr ck_tile::index_t N = 16;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t MAX_ACTIVE_WGS = 4;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
@@ -366,7 +366,7 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx,
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
ck_tile::DeviceMem im_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem in_dev(sizeof(ck_tile::index_t));
|
||||
|
||||
@@ -402,7 +402,7 @@ void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start,
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<typename Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
Config::M, Config::N, Config::K, Config::MAX_ACTIVE_WGS};
|
||||
ck_tile::DeviceMem tile_local_cta_idx_dev(sizeof(ck_tile::index_t));
|
||||
|
||||
// Launch kernel
|
||||
@@ -426,7 +426,7 @@ struct StreamKTilePartitionerV2PersistentExpected
|
||||
{
|
||||
ck_tile::index_t dp_tiles_per_cta_;
|
||||
ck_tile::index_t extra_dp_tiles_;
|
||||
ck_tile::index_t grid_;
|
||||
ck_tile::index_t max_active_wgs_;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerV2NonPersistentExpected
|
||||
@@ -434,7 +434,7 @@ struct StreamKTilePartitionerV2NonPersistentExpected
|
||||
ck_tile::index_t dp_ctas_;
|
||||
ck_tile::index_t dp_start_block_idx_;
|
||||
ck_tile::index_t sk_start_block_idx_;
|
||||
ck_tile::index_t grid_;
|
||||
ck_tile::index_t max_active_wgs_;
|
||||
};
|
||||
|
||||
// Persistent
|
||||
@@ -446,7 +446,7 @@ void validate_streamk_persistent(
|
||||
{
|
||||
EXPECT_EQ(tile_partitioner.get_dp_tiles_per_cta(), expected_values.dp_tiles_per_cta_);
|
||||
EXPECT_EQ(tile_partitioner.get_extra_dp_tiles(), expected_values.extra_dp_tiles_);
|
||||
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
|
||||
EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.max_active_wgs_);
|
||||
}
|
||||
|
||||
// Non-Persistent
|
||||
@@ -459,5 +459,5 @@ void validate_streamk_nonpersistent(
|
||||
EXPECT_EQ(tile_partitioner.get_dp_ctas(), expected_values.dp_ctas_);
|
||||
EXPECT_EQ(tile_partitioner.get_dp_start_block_idx(), expected_values.dp_start_block_idx_);
|
||||
EXPECT_EQ(tile_partitioner.get_sk_start_block_idx(), expected_values.sk_start_block_idx_);
|
||||
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
|
||||
EXPECT_EQ(tile_partitioner.get_max_active_wgs(), expected_values.max_active_wgs_);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user