mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK-Tile] Remove usage of tile partitioner's full gemm shape (#3204)
gemm shape should be used from the pipeline instead (where it gets from a problem description struct)
This commit is contained in:
@@ -386,11 +386,9 @@ template <typename BlockGemmShapeType,
|
||||
uint32_t TileSwizzleSubM = 8>
|
||||
struct StreamKTilePartitioner
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShapeType;
|
||||
|
||||
static constexpr uint32_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr uint32_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr uint32_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr uint32_t MPerBlock = BlockGemmShapeType::kM;
|
||||
static constexpr uint32_t NPerBlock = BlockGemmShapeType::kN;
|
||||
static constexpr uint32_t KPerBlock = BlockGemmShapeType::kK;
|
||||
|
||||
CK_TILE_HOST_DEVICE StreamKTilePartitioner() noexcept = delete;
|
||||
|
||||
|
||||
@@ -22,11 +22,10 @@ template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType = StreamKReductionStrategy::Atomic>
|
||||
struct StreamKTilePartitionerBase
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShapeType;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr index_t MPerBlock = BlockGemmShapeType::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShapeType::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShapeType::kK;
|
||||
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
|
||||
|
||||
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
|
||||
|
||||
@@ -325,7 +325,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
|
||||
{
|
||||
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
|
||||
const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
|
||||
const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1);
|
||||
|
||||
@@ -584,7 +584,7 @@ struct UniversalGemmKernel
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
|
||||
const auto& as_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -617,7 +617,7 @@ struct UniversalGemmKernel
|
||||
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
|
||||
if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
@@ -649,7 +649,7 @@ struct UniversalGemmKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
if constexpr(GemmPipeline::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = k_size / K1;
|
||||
@@ -675,8 +675,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::BlockGemmShape::flatKPerWarp *
|
||||
(k_size /
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
(k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
|
||||
Reference in New Issue
Block a user