mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Addition of the derived structs for the new Stream-K TilePartitioner
There are 2 derived structs based on whether Stream-K is persistent or not.
If it's persistent that means that both the data parallel and Stream-K sections
are data parallel. If it's non-persistent that means that only the
Stream-K section is persistent, while the data parallel section will have
separate workgroups allocated for it. Both structs will have a template
argument for Persistent.
The 2 derived classes will inherit common variables and functions from the
Stream-K TilePartitioner base class. There are additional variables for the
differing data parallel sections that will be added to each derived class,
that are in charge of the indexing/bookkeeping for the data parallel sections.
The only additional function that will differ between the 2 structs is GridSize(),
as the non-persistent will allocate extra workgroups for data parallel.
Unit tests for the derived structs are included.
[ROCm/composable_kernel commit: 8f75d7cea6]
This commit is contained in:
@@ -10,6 +10,8 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include <format>
|
||||
#include <iostream>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -810,4 +812,5 @@ struct StreamKTilePartitioner
|
||||
uint32_t M_, N_, K_;
|
||||
uint32_t num_tile_m_, num_tile_n_, num_tile_k_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -202,6 +202,117 @@ struct StreamKTilePartitionerBase
|
||||
index_t n_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Template for the Stream-K tile partitioner derived struct.
|
||||
*
|
||||
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
|
||||
* for the Stream-K algorithm. This struct is derived from
|
||||
* StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>. Behavior of the
|
||||
* StreamKTilePartitioner based on persistency will be in the template specializations.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C
|
||||
* Tensor.
|
||||
* @tparam Persistent A bool that indicates whether to use a Persistent approach
|
||||
*/
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy, bool Persistent>
|
||||
struct StreamKTilePartitioner_v2;
|
||||
|
||||
/**
|
||||
* @brief Persistent Stream-K tile partitioner derived struct.
|
||||
*
|
||||
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
|
||||
* for the Stream-K algorithm when using a Persistent approach where no extra workgroups
|
||||
* are allocated for data parallel.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C
|
||||
* Tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>
|
||||
{
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Calculates the launching grid size for the Stream-K kernel. In the Persistent
|
||||
* case, no extra workgroups are allocated for the data parallel section, making the grid
|
||||
* size num_cu * occupancy.
|
||||
*
|
||||
* @return dim_3 The launching grid size for the kernel.
|
||||
*/
|
||||
CK_TILE_HOST auto grid_size() const noexcept -> dim3;
|
||||
|
||||
CK_TILE_HOST_DEVICE index_t get_dp_tiles_per_cta() const noexcept;
|
||||
CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept;
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief The total number of DP tiles per workgroup.
|
||||
*/
|
||||
int dp_tiles_per_cta_;
|
||||
|
||||
/**
|
||||
* @brief The total number of DP tiles left over when dp_tiles is not evenly divisible by grid.
|
||||
*/
|
||||
int extra_dp_tiles_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Non-Persistent Stream-K tile partitioner derived struct.
|
||||
*
|
||||
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
|
||||
* for the Stream-K algorithm when using a Non-Persistent approach where extra workgroups
|
||||
* are allocated for the data parallel section.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C
|
||||
* Tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>
|
||||
{
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Calculates the launching grid size for the Stream-K kernel. In the Non-Persistent
|
||||
* case, extra workgroups are allocated for the data parallel section, making the grid
|
||||
* size the total number of Stream-K and data parallel workgroups.
|
||||
*
|
||||
* @return dim_3 The launching grid size for the kernel.
|
||||
*/
|
||||
CK_TILE_HOST auto grid_size() const noexcept -> dim3;
|
||||
CK_TILE_HOST_DEVICE index_t get_dp_ctas() const noexcept;
|
||||
CK_TILE_HOST_DEVICE index_t get_dp_start_block_idx() const noexcept;
|
||||
CK_TILE_HOST_DEVICE index_t get_sk_start_block_idx() const noexcept;
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief The total number of DP workgroups.
|
||||
*/
|
||||
int dp_ctas_;
|
||||
|
||||
/**
|
||||
* @brief The index that starts the DP workgroups, always 0 in our implementation.
|
||||
*/
|
||||
int dp_start_block_idx_;
|
||||
|
||||
/**
|
||||
* @brief The index that starts the Stream-K workgroups, set to the number of dp_tiles.
|
||||
*/
|
||||
int sk_start_block_idx_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "streamk_gemm_tile_partitioner_impl.hpp"
|
||||
|
||||
@@ -211,4 +211,91 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_n() const
|
||||
return n_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy, bool Persistent>
|
||||
struct StreamKTilePartitioner_v2;
|
||||
|
||||
// child class for Persistent Tile Partitioner
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>::StreamKTilePartitioner_v2(
|
||||
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>(m, n, k, grid)
|
||||
{ // inherit from base constructor
|
||||
dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
|
||||
extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
CK_TILE_HOST auto
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>::grid_size() const noexcept
|
||||
-> dim3
|
||||
{
|
||||
if(extra_dp_tiles_ == 0)
|
||||
{
|
||||
return dim3(this->grid_, 1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(this->num_tiles_, 1, 1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>::get_dp_tiles_per_cta()
|
||||
const noexcept
|
||||
{
|
||||
return dp_tiles_per_cta_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>::get_extra_dp_tiles()
|
||||
const noexcept
|
||||
{
|
||||
return extra_dp_tiles_;
|
||||
}
|
||||
|
||||
// child class for Non-Persistent Tile Partitioner
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>::StreamKTilePartitioner_v2(
|
||||
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>(m, n, k, grid)
|
||||
{ // inherit from base constructor
|
||||
dp_ctas_ = this->dp_tiles_;
|
||||
dp_start_block_idx_ = 0;
|
||||
sk_start_block_idx_ = this->dp_tiles_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
CK_TILE_HOST auto
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>::grid_size() const noexcept
|
||||
-> dim3
|
||||
{
|
||||
return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1);
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>::get_dp_ctas()
|
||||
const noexcept
|
||||
{
|
||||
return dp_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>::get_dp_start_block_idx()
|
||||
const noexcept
|
||||
{
|
||||
return dp_start_block_idx_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>::get_sk_start_block_idx()
|
||||
const noexcept
|
||||
{
|
||||
return sk_start_block_idx_;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -347,3 +347,148 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Persistent
|
||||
TEST(StreamKTilePartitioner_v2_PersistentConstructor, SKOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{0, 0, 3};
|
||||
validate_streamk_v2_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_PersistentConstructor, DPOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDPOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{2, 0, 3};
|
||||
validate_streamk_v2_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_PersistentConstructor, DP2TileSK)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{1, 0, 3};
|
||||
validate_streamk_v2_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_PersistentConstructor, EdgeCase)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigEdgeCase;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{0, 1, 4};
|
||||
validate_streamk_v2_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_GridSize_Persistent, SKOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
const auto g = tile_partitioner.grid_size();
|
||||
EXPECT_EQ(g.x, Config::GRID);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_GridSize_Persistent, EdgeCase)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigEdgeCase;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
const auto g = tile_partitioner.grid_size();
|
||||
EXPECT_EQ(g.x, 1);
|
||||
}
|
||||
|
||||
// Non-Persistent Tests
|
||||
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, SKOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{0, 0, 0, 3};
|
||||
validate_streamk_v2_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, DPOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDPOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{6, 0, 6, 3};
|
||||
validate_streamk_v2_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, DP2TileSK)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{3, 0, 3, 3};
|
||||
validate_streamk_v2_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, EdgeCase)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigEdgeCase;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{1, 0, 1, 4};
|
||||
validate_streamk_v2_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_GridSize_NonPersistent, DP2TileSK)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
const auto g = tile_partitioner.grid_size();
|
||||
EXPECT_EQ(g.x, 6);
|
||||
}
|
||||
|
||||
@@ -297,4 +297,45 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx,
|
||||
in_dev.FromDevice(&in);
|
||||
EXPECT_EQ(im, im_expected);
|
||||
EXPECT_EQ(in, in_expected);
|
||||
};
|
||||
|
||||
// Configs for TilePartitioner Child structs
|
||||
struct StreamKTilePartitionerV2PersistentExpected
|
||||
{
|
||||
ck_tile::index_t dp_tiles_per_cta_;
|
||||
ck_tile::index_t extra_dp_tiles_;
|
||||
ck_tile::index_t grid_;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerV2NonPersistentExpected
|
||||
{
|
||||
ck_tile::index_t dp_ctas_;
|
||||
ck_tile::index_t dp_start_block_idx_;
|
||||
ck_tile::index_t sk_start_block_idx_;
|
||||
ck_tile::index_t grid_;
|
||||
};
|
||||
|
||||
// Persistent
|
||||
template <typename GemmShape>
|
||||
void validate_streamk_v2_persistent(
|
||||
StreamKTilePartitionerV2PersistentExpected& expected_values,
|
||||
ck_tile::StreamKTilePartitioner_v2<GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true>&
|
||||
tile_partitioner)
|
||||
{
|
||||
EXPECT_EQ(tile_partitioner.get_dp_tiles_per_cta(), expected_values.dp_tiles_per_cta_);
|
||||
EXPECT_EQ(tile_partitioner.get_extra_dp_tiles(), expected_values.extra_dp_tiles_);
|
||||
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
|
||||
}
|
||||
|
||||
// Non-Persistent
|
||||
template <typename GemmShape>
|
||||
void validate_streamk_v2_nonpersistent(
|
||||
StreamKTilePartitionerV2NonPersistentExpected& expected_values,
|
||||
ck_tile::StreamKTilePartitioner_v2<GemmShape, ck_tile::StreamKReductionStrategy::Atomic, false>&
|
||||
tile_partitioner)
|
||||
{
|
||||
EXPECT_EQ(tile_partitioner.get_dp_ctas(), expected_values.dp_ctas_);
|
||||
EXPECT_EQ(tile_partitioner.get_dp_start_block_idx(), expected_values.dp_start_block_idx_);
|
||||
EXPECT_EQ(tile_partitioner.get_sk_start_block_idx(), expected_values.sk_start_block_idx_);
|
||||
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user