Addition of the derived structs for the new Stream-K TilePartitioner

There are 2 derived structs based on whether Stream-K is persistent or not.
If it's persistent, both the data parallel and Stream-K sections
are persistent. If it's non-persistent, only the
Stream-K section is persistent, while the data parallel section will have
separate workgroups allocated for it. Both structs will have a template
argument for Persistent.

The 2 derived structs will inherit common variables and functions from the
Stream-K TilePartitioner base class. Additional variables for the
differing data parallel sections will be added to each derived struct;
they are in charge of the indexing/bookkeeping for the data parallel sections.
The only additional function that will differ between the 2 structs is grid_size(),
as the non-persistent struct will allocate extra workgroups for the data parallel section.

Unit tests for the derived structs are included.
This commit is contained in:
Astha
2025-10-06 15:01:10 -04:00
committed by Emily Martins
parent f87f768d16
commit 8f75d7cea6
5 changed files with 387 additions and 0 deletions

View File

@@ -347,3 +347,148 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings)
}
}
}
// Persistent Tests
// Constructs a persistent partitioner from the SK-only config and checks the
// values reported by its getters.
TEST(StreamKTilePartitioner_v2_PersistentConstructor, SKOnly)
{
    using Config = StreamKTilePartitionerBaseConfigSKOnly;
    using Partitioner =
        ck_tile::StreamKTilePartitioner_v2<Config::GemmShape,
                                           ck_tile::StreamKReductionStrategy::Atomic,
                                           true>;

    Partitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};

    StreamKTilePartitionerV2PersistentExpected expected{
        /*dp_tiles_per_cta_=*/0, /*extra_dp_tiles_=*/0, /*grid_=*/3};
    validate_streamk_v2_persistent<Config::GemmShape>(expected, partitioner);
}
// Constructs a persistent partitioner from the DP-only config and checks the
// values reported by its getters.
TEST(StreamKTilePartitioner_v2_PersistentConstructor, DPOnly)
{
    using Config = StreamKTilePartitionerBaseConfigDPOnly;
    using Partitioner =
        ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
                                           ck_tile::StreamKReductionStrategy::Atomic,
                                           true>;

    Partitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};

    StreamKTilePartitionerV2PersistentExpected expected{
        /*dp_tiles_per_cta_=*/2, /*extra_dp_tiles_=*/0, /*grid_=*/3};
    validate_streamk_v2_persistent<Config::GemmShape>(expected, partitioner);
}
// Constructs a persistent partitioner from the DP2TileSK config and checks the
// values reported by its getters.
TEST(StreamKTilePartitioner_v2_PersistentConstructor, DP2TileSK)
{
    using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
    using Partitioner =
        ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
                                           ck_tile::StreamKReductionStrategy::Atomic,
                                           true>;

    Partitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};

    StreamKTilePartitionerV2PersistentExpected expected{
        /*dp_tiles_per_cta_=*/1, /*extra_dp_tiles_=*/0, /*grid_=*/3};
    validate_streamk_v2_persistent<Config::GemmShape>(expected, partitioner);
}
// Constructs a persistent partitioner from the edge-case config and checks the
// values reported by its getters.
TEST(StreamKTilePartitioner_v2_PersistentConstructor, EdgeCase)
{
    using Config = StreamKTilePartitionerBaseConfigEdgeCase;
    using Partitioner =
        ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
                                           ck_tile::StreamKReductionStrategy::Atomic,
                                           true>;

    Partitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};

    StreamKTilePartitionerV2PersistentExpected expected{
        /*dp_tiles_per_cta_=*/0, /*extra_dp_tiles_=*/1, /*grid_=*/4};
    validate_streamk_v2_persistent<Config::GemmShape>(expected, partitioner);
}
// For the SK-only config, the x component of the persistent partitioner's
// grid_size() is expected to equal Config::GRID.
TEST(StreamKTilePartitioner_v2_GridSize_Persistent, SKOnly)
{
    using Config = StreamKTilePartitionerBaseConfigSKOnly;
    using Partitioner =
        ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
                                           ck_tile::StreamKReductionStrategy::Atomic,
                                           true>;

    Partitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};

    const auto launch_dims = partitioner.grid_size();
    EXPECT_EQ(launch_dims.x, Config::GRID);
}
// For the edge-case config, the x component of the persistent partitioner's
// grid_size() is expected to be 1.
TEST(StreamKTilePartitioner_v2_GridSize_Persistent, EdgeCase)
{
    using Config = StreamKTilePartitionerBaseConfigEdgeCase;
    using Partitioner =
        ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
                                           ck_tile::StreamKReductionStrategy::Atomic,
                                           true>;

    Partitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};

    const auto launch_dims = partitioner.grid_size();
    EXPECT_EQ(launch_dims.x, 1);
}
// Non-Persistent Tests
// Constructs a non-persistent partitioner from the SK-only config and checks
// the values reported by its getters.
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, SKOnly)
{
    using Config = StreamKTilePartitionerBaseConfigSKOnly;
    using Partitioner =
        ck_tile::StreamKTilePartitioner_v2<Config::GemmShape,
                                           ck_tile::StreamKReductionStrategy::Atomic,
                                           false>;

    Partitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};

    StreamKTilePartitionerV2NonPersistentExpected expected{/*dp_ctas_=*/0,
                                                           /*dp_start_block_idx_=*/0,
                                                           /*sk_start_block_idx_=*/0,
                                                           /*grid_=*/3};
    validate_streamk_v2_nonpersistent<Config::GemmShape>(expected, partitioner);
}
// Constructs a non-persistent partitioner from the DP-only config and checks
// the values reported by its getters.
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, DPOnly)
{
    using Config = StreamKTilePartitionerBaseConfigDPOnly;
    using Partitioner =
        ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
                                           ck_tile::StreamKReductionStrategy::Atomic,
                                           false>;

    Partitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};

    StreamKTilePartitionerV2NonPersistentExpected expected{/*dp_ctas_=*/6,
                                                           /*dp_start_block_idx_=*/0,
                                                           /*sk_start_block_idx_=*/6,
                                                           /*grid_=*/3};
    validate_streamk_v2_nonpersistent<Config::GemmShape>(expected, partitioner);
}
// Constructs a non-persistent partitioner from the DP2TileSK config and checks
// the values reported by its getters.
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, DP2TileSK)
{
    using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
    using Partitioner =
        ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
                                           ck_tile::StreamKReductionStrategy::Atomic,
                                           false>;

    Partitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};

    StreamKTilePartitionerV2NonPersistentExpected expected{/*dp_ctas_=*/3,
                                                           /*dp_start_block_idx_=*/0,
                                                           /*sk_start_block_idx_=*/3,
                                                           /*grid_=*/3};
    validate_streamk_v2_nonpersistent<Config::GemmShape>(expected, partitioner);
}
// Constructs a non-persistent partitioner from the edge-case config and checks
// the values reported by its getters.
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, EdgeCase)
{
    using Config = StreamKTilePartitionerBaseConfigEdgeCase;
    using Partitioner =
        ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
                                           ck_tile::StreamKReductionStrategy::Atomic,
                                           false>;

    Partitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};

    StreamKTilePartitionerV2NonPersistentExpected expected{/*dp_ctas_=*/1,
                                                           /*dp_start_block_idx_=*/0,
                                                           /*sk_start_block_idx_=*/1,
                                                           /*grid_=*/4};
    validate_streamk_v2_nonpersistent<Config::GemmShape>(expected, partitioner);
}
// For the DP2TileSK config, the x component of the non-persistent partitioner's
// grid_size() is expected to be 6.
TEST(StreamKTilePartitioner_v2_GridSize_NonPersistent, DP2TileSK)
{
    using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
    using Partitioner =
        ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
                                           ck_tile::StreamKReductionStrategy::Atomic,
                                           false>;

    Partitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};

    const auto launch_dims = partitioner.grid_size();
    EXPECT_EQ(launch_dims.x, 6);
}

View File

@@ -297,4 +297,45 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx,
in_dev.FromDevice(&in);
EXPECT_EQ(im, im_expected);
EXPECT_EQ(in, in_expected);
};
// Expected values for the StreamKTilePartitioner_v2 derived (child) structs
// Expected getter results for a persistent StreamKTilePartitioner_v2.
// Consumed by validate_streamk_v2_persistent, which compares each field with
// the partitioner getter named in the per-field comment. Tests brace-initialize
// this aggregate positionally, so the member order must not change.
struct StreamKTilePartitionerV2PersistentExpected
{
// Compared against get_dp_tiles_per_cta().
ck_tile::index_t dp_tiles_per_cta_;
// Compared against get_extra_dp_tiles().
ck_tile::index_t extra_dp_tiles_;
// Compared against get_grid().
ck_tile::index_t grid_;
};
// Expected getter results for a non-persistent StreamKTilePartitioner_v2.
// Consumed by validate_streamk_v2_nonpersistent, which compares each field with
// the partitioner getter named in the per-field comment. Tests brace-initialize
// this aggregate positionally, so the member order must not change.
struct StreamKTilePartitionerV2NonPersistentExpected
{
// Compared against get_dp_ctas().
ck_tile::index_t dp_ctas_;
// Compared against get_dp_start_block_idx().
ck_tile::index_t dp_start_block_idx_;
// Compared against get_sk_start_block_idx().
ck_tile::index_t sk_start_block_idx_;
// Compared against get_grid().
ck_tile::index_t grid_;
};
// Persistent
template <typename GemmShape>
void validate_streamk_v2_persistent(
StreamKTilePartitionerV2PersistentExpected& expected_values,
ck_tile::StreamKTilePartitioner_v2<GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true>&
tile_partitioner)
{
EXPECT_EQ(tile_partitioner.get_dp_tiles_per_cta(), expected_values.dp_tiles_per_cta_);
EXPECT_EQ(tile_partitioner.get_extra_dp_tiles(), expected_values.extra_dp_tiles_);
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
}
// Non-Persistent
template <typename GemmShape>
void validate_streamk_v2_nonpersistent(
StreamKTilePartitionerV2NonPersistentExpected& expected_values,
ck_tile::StreamKTilePartitioner_v2<GemmShape, ck_tile::StreamKReductionStrategy::Atomic, false>&
tile_partitioner)
{
EXPECT_EQ(tile_partitioner.get_dp_ctas(), expected_values.dp_ctas_);
EXPECT_EQ(tile_partitioner.get_dp_start_block_idx(), expected_values.dp_start_block_idx_);
EXPECT_EQ(tile_partitioner.get_sk_start_block_idx(), expected_values.sk_start_block_idx_);
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
}