[CK_TILE] Stream-K Tree Reduction and Cache Skipping Integration (#3371)

* CK Tile Stream-K Tree Reduction

This change adds the first implementation of the Stream-K tree reduction
strategy into CK Tile. The tree reduction reduces the number of
steps for accumulating results for a tile from O(N) to O(logN) where N
is the number of workgroups contributing to a C tile.

Additionally, in the original non-atomic reduction strategy, atomics
were used to set the flags buffer and to read from the flags buffer.
However, through investigation with the tree reduction, atomics with
default (relaxed) semantics were not enough to guarantee workgroups
would not read stale data, leading to incorrect results. Stronger
acquire/release memory orderings are too expensive. So, this change
also eliminates the use of atomics for setting the flags. Instead, we
leverage cache modifiers (e.g., GLC) to avoid writing to cache, thereby
avoiding the use of atomics.

Preliminary tests were also added for the normal reduction and tree
reduction. More will be added in a future PR via tile engine.

* Move Stream-K kernel files to a subdirectory

* Cleanup Code Style & Handle Unsupported Reductions

This change makes the following small changes:
- Add an explicit else block for unimplemented reduction strategies
- Clarify type of sk_flags_ptr via auto*
- Add description for extra_iters_before_me variable

* Run new copyright script on new files

[ROCm/composable_kernel commit: 22b945e06e]
This commit is contained in:
Emily Martins
2025-12-14 14:49:49 -07:00
committed by GitHub
parent a3270d2eb0
commit eeb78c46a4
13 changed files with 524 additions and 70 deletions

View File

@@ -23,6 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
#TODO: support all arches
#TODO: current c-shuffle only supports C layout as R
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
# Stream-K reduction test executable: built from the fp16 reduction smoke test
# source plus the shared Stream-K test utility source.
add_gtest_executable(test_ck_tile_streamk_reduction
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
test_gemm_streamk_util.cpp)
add_gtest_executable(test_ck_tile_streamk_smoke
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
/// Typed-test fixture used for the fp16 Stream-K reduction test suite.
/// It adds no members of its own; everything comes from TestCkTileStreamK<Tuple>.
template <class Tuple>
class TestCkTileStreamKFp16Reduction : public TestCkTileStreamK<Tuple>
{
    // Intentionally empty.
};
// TEST_SUITE_NAME must name this fixture while the shared test cases are included;
// the included .inc is expected to declare its TYPED_TESTs through this macro.
#define TEST_SUITE_NAME TestCkTileStreamKFp16Reduction
// Register the fixture against the fp16 reduction kernel type list.
TYPED_TEST_SUITE(TestCkTileStreamKFp16Reduction, KernelTypesStreamKFp16Reduction);
#include "test_gemm_streamk_reduction_cases.inc"
// Remove the macro again so it does not leak past the included test cases.
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,88 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// Single C macro tile (M and N each equal one macro tile) with K equal to one K macro
// tile per value returned by get_cu_count(); run with the TreeReduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile_Tree)
{
    // Tuple slots 7, 8 and 9 carry the M, N and K macro-tile sizes.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;

    const ck_tile::index_t cu_count = get_cu_count();

    // Problem size: exactly one output tile, K scaled by cu_count.
    ck_tile::index_t m = tile_m;
    ck_tile::index_t n = tile_n;
    ck_tile::index_t k = tile_k * cu_count;

    this->Run(m, n, k, ck_tile::StreamKReductionStrategy::TreeReduction);
}
// Single C macro tile (M and N each equal one macro tile) with K equal to one K macro
// tile per value returned by get_cu_count(); run with the Reduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile)
{
    // Tuple slots 7, 8 and 9 carry the M, N and K macro-tile sizes.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;

    const ck_tile::index_t cu_count = get_cu_count();

    // Problem size: exactly one output tile, K scaled by cu_count.
    ck_tile::index_t m = tile_m;
    ck_tile::index_t n = tile_n;
    ck_tile::index_t k = tile_k * cu_count;

    this->Run(m, n, k, ck_tile::StreamKReductionStrategy::Reduction);
}
// Four C macro tiles (4 along M, 1 along N) with K equal to (get_cu_count() + 25)
// K macro tiles; run with the TreeReduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Tree)
{
    // Tuple slots 7, 8 and 9 carry the M, N and K macro-tile sizes.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;

    const ck_tile::index_t cu_count = get_cu_count();

    // Problem size: 4 x 1 output tiles, K is cu_count K-tiles plus 25 more.
    ck_tile::index_t m = tile_m * 4;
    ck_tile::index_t n = tile_n;
    ck_tile::index_t k = tile_k * cu_count + (25 * tile_k);

    this->Run(m, n, k, ck_tile::StreamKReductionStrategy::TreeReduction);
}
// Four C macro tiles (4 along M, 1 along N) with K equal to (get_cu_count() + 25)
// K macro tiles; run with the Reduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction)
{
    // Tuple slots 7, 8 and 9 carry the M, N and K macro-tile sizes.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;

    const ck_tile::index_t cu_count = get_cu_count();

    // Problem size: 4 x 1 output tiles, K is cu_count K-tiles plus 25 more.
    ck_tile::index_t m = tile_m * 4;
    ck_tile::index_t n = tile_n;
    ck_tile::index_t k = tile_k * cu_count + (25 * tile_k);

    this->Run(m, n, k, ck_tile::StreamKReductionStrategy::Reduction);
}
// Twenty-one C macro tiles (3 along M, 7 along N) with K equal to
// (get_cu_count() + 30) K macro tiles; run with the TreeReduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles_Tree)
{
    // Tuple slots 7, 8 and 9 carry the M, N and K macro-tile sizes.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;

    const ck_tile::index_t cu_count = get_cu_count();

    // Problem size: 3 x 7 output tiles, K is cu_count K-tiles plus 30 more.
    ck_tile::index_t m = tile_m * 3;
    ck_tile::index_t n = tile_n * 7;
    ck_tile::index_t k = tile_k * cu_count + (30 * tile_k);

    this->Run(m, n, k, ck_tile::StreamKReductionStrategy::TreeReduction);
}
// Twenty-one C macro tiles (3 along M, 7 along N) with K equal to
// (get_cu_count() + 30) K macro tiles; run with the Reduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles)
{
    // Tuple slots 7, 8 and 9 carry the M, N and K macro-tile sizes.
    constexpr ck_tile::index_t tile_m = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t tile_n = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t tile_k = std::tuple_element_t<9, TypeParam>::value;

    const ck_tile::index_t cu_count = get_cu_count();

    // Problem size: 3 x 7 output tiles, K is cu_count K-tiles plus 30 more.
    ck_tile::index_t m = tile_m * 3;
    ck_tile::index_t n = tile_n * 7;
    ck_tile::index_t k = tile_k * cu_count + (30 * tile_k);

    this->Run(m, n, k, ck_tile::StreamKReductionStrategy::Reduction);
}

View File

@@ -33,6 +33,14 @@ using KernelTypesStreamKFp16Persistent = ::testing::Types<
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
>;
// Type list for the fp16 Stream-K reduction test suite: the four A/B layout
// combinations tagged Persistent, plus one Row/Col configuration tagged NonPersistent.
// Tuple positions follow the column header below (index 7/8/9 are the macro-tile sizes).
using KernelTypesStreamKFp16Reduction = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>>;
using KernelTypesStreamKBf16Persistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,

View File

@@ -144,7 +144,11 @@ class TestCkTileStreamK : public ::testing::Test
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -184,11 +188,6 @@ class TestCkTileStreamK : public ::testing::Test
using namespace ck_tile::literals;
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
{
throw std::runtime_error("Reduction Strategy is current unsupported!\n");
}
auto f_host_tensor_descriptor = [](std::size_t row,
std::size_t col,
std::size_t stride,
@@ -252,9 +251,25 @@ class TestCkTileStreamK : public ::testing::Test
stride_B,
stride_C};
ck_tile::index_t num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
ck_tile::index_t num_accumulations_per_tile;
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
{
num_accumulations_per_tile = invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
{
num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::Reduction>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
else
{
num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::TreeReduction>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

View File

@@ -372,6 +372,85 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings)
}
}
// Verifies the tile-local CTA index for every (tile, CTA) pairing of the
// Stream-K-only, large-K configuration.
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, SKOnlyLargeK)
{
    /*
    StreamKTilePartitionerBaseConfigSKOnlyLargeK describes:
        - tiles in the C tensor: 2
        - iters_per_tile:        5
        - grid:                  5
        - dp_tiles:              0
        - sk_tiles:              2
        - iters_per_sk_cta:      2
        - extra_iters:           0

    Iteration space:

        tile_idx:  |_________0_________|_________1_________|
        tile_iter: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
                   |<--------------SK Tiles--------------->|

    Expected tile-local CTA indices:
        - SK CTA 0: tile_iter_start 0 -> local CTA index 0 in tile 0
        - SK CTA 1: tile_iter_start 0 -> local CTA index 1 in tile 0
        - SK CTA 2: tile_iter_start 0 -> local CTA index 2 in tile 0
        - SK CTA 2: tile_iter_start 5 -> local CTA index 0 in tile 1
        - SK CTA 3: tile_iter_start 5 -> local CTA index 1 in tile 1
        - SK CTA 4: tile_iter_start 5 -> local CTA index 2 in tile 1
    */

    // Each entry is (tile_iter_start, cta_idx, expected tile_local_cta_idx).
    using Triplet = std::array<ck_tile::index_t, 3>;
    const std::vector<Triplet> expectations{
        {0, 0, 0}, {0, 1, 1}, {0, 2, 2}, {5, 2, 0}, {5, 3, 1}, {5, 4, 2}};

    for(const auto& [iter_start, cta, local_cta] : expectations)
    {
        test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigSKOnlyLargeK>(
            iter_start, cta, local_cta);
    }
}
// Verifies the tile-local CTA index for the Stream-K portion of the mixed
// data-parallel / Stream-K configuration.
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, DP2TileSK)
{
    /*
    StreamKTilePartitionerBaseConfigDP2TileSK describes:
        - tiles in the C tensor: 7
        - iters_per_tile:        2
          NOTE(review): this comment previously said 3, but the diagram and the
          tile_iter_start values 6/8/10/12 below show two iterations per tile;
          confirm against the config definition.
        - grid:                  3
        - dp_tiles:              3
        - sk_tiles:              4
        - iters_per_sk_cta:      2
        - extra_iters:           2

    Iteration space:

        tile_idx:  |___0___|___1___|___2___|___3___|___4___|____5____|____6____|
        tile_iter: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
                   |<-------DP Tiles------>|<-------------SK Tiles------------>|

    Expected tile-local CTA indices:
        - SK CTA 0: tile_iter_start 6  -> local CTA index 0 in tile 3
        - SK CTA 0: tile_iter_start 8  -> local CTA index 0 in tile 4
        - SK CTA 1: tile_iter_start 8  -> local CTA index 1 in tile 4
        - SK CTA 1: tile_iter_start 10 -> local CTA index 0 in tile 5
        - SK CTA 2: tile_iter_start 12 -> local CTA index 0 in tile 6
    */

    // Each entry is (tile_iter_start, cta_idx, expected tile_local_cta_idx).
    using Triplet = std::array<ck_tile::index_t, 3>;
    const std::vector<Triplet> expectations{
        {6, 0, 0}, {8, 0, 0}, {8, 1, 1}, {10, 1, 0}, {12, 2, 0}};

    for(const auto& [iter_start, cta, local_cta] : expectations)
    {
        test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigDP2TileSK>(
            iter_start, cta, local_cta);
    }
}
// Persistent
TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly)
{

View File

@@ -4,6 +4,7 @@
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "gtest/gtest.h"
#include <array>
enum StreamKTilePartitionerBaseMethodId
{
@@ -12,7 +13,8 @@ enum StreamKTilePartitionerBaseMethodId
GET_TILE_BOUNDARIES,
GET_TILE_INDEX,
GET_ITER_BOUNDARIES,
GET_OUTPUT_TILE_INDEX
GET_OUTPUT_TILE_INDEX,
GET_TILE_LOCAL_CTA_INDEX
};
// Base kernel wrapper class to facilitate testing class device functions.
@@ -136,6 +138,22 @@ struct KernelWrapperSpecialized<TilePartitioner,
}
};
// Kernel wrapper specialization for GET_TILE_LOCAL_CTA_INDEX: calls
// get_tile_local_cta_index(arg1, arg2) on the tile partitioner carried in the kernel
// arguments and stores the returned index through result1.
template <typename TilePartitioner>
struct KernelWrapperSpecialized<TilePartitioner,
                                StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>
    : public KernelWrapper<TilePartitioner>
{
    using Base = KernelWrapper<TilePartitioner>;

    CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
    {
        // result1 is an untyped pointer; the caller is expected to supply storage
        // for one ck_tile::index_t.
        auto* out = static_cast<ck_tile::index_t*>(kargs.result1);

        const ck_tile::index_t local_cta_index =
            kargs.tile_partitioner.get_tile_local_cta_index(kargs.arg1, kargs.arg2);

        *out = local_cta_index;
    }
};
struct StreamKTilePartitionerBaseExpected
{
ck_tile::index_t sk_tiles_;
@@ -243,6 +261,22 @@ struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBas
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};
// Test configuration with a K dimension that is large relative to the tile size.
// With the values below the C tensor has (M / M_TILE) * (N / N_TILE) = 2 * 1 = 2 tiles,
// each tile needs K / K_TILE = 5 iterations, and the grid size is 5.
struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitionerBaseConfig
{
// Problem dimensions.
static constexpr ck_tile::index_t M = 8;
static constexpr ck_tile::index_t N = 2;
static constexpr ck_tile::index_t K = 10;
// Grid size handed to the tile partitioner constructor by the test helpers.
static constexpr ck_tile::index_t GRID = 5;
// Macro-tile dimensions.
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 2;
static constexpr ck_tile::index_t K_TILE = 2;
// Only the first sequence (the macro-tile shape) carries meaningful values here;
// the remaining two are filled with UNUSED.
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};
struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig
{
@@ -314,6 +348,38 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx,
EXPECT_EQ(in, in_expected);
};
// Runs get_tile_local_cta_index(tile_iter_start, cta_idx) on the device for a
// partitioner built from Config and checks the value copied back to the host.
//
// tile_iter_start:             first argument forwarded to get_tile_local_cta_index
// cta_idx:                     second argument forwarded to get_tile_local_cta_index
// expected_tile_local_cta_idx: value the device result is compared against
template <typename Config>
void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start,
                                 ck_tile::index_t cta_idx,
                                 ck_tile::index_t expected_tile_local_cta_idx)
{
    using TilePartitioner = ck_tile::StreamKTilePartitionerBase<typename Config::GemmShape>;
    using Kernel =
        KernelWrapperSpecialized<TilePartitioner,
                                 StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>;

    // Partitioner under test and a device buffer holding a single index_t result.
    TilePartitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};
    ck_tile::DeviceMem result_dev(sizeof(ck_tile::index_t));

    // The third argument slot and the second result slot are not used by this kernel.
    auto kernel_args = Kernel::MakeKernelArgs(tile_iter_start,
                                              cta_idx,
                                              Config::UNUSED,
                                              result_dev.GetDeviceBuffer(),
                                              nullptr,
                                              partitioner);

    ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
                           ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kernel_args));

    // Copy the result back to the host and compare.
    ck_tile::index_t tile_local_cta_idx;
    result_dev.FromDevice(&tile_local_cta_idx);
    EXPECT_EQ(tile_local_cta_idx, expected_tile_local_cta_idx);
}
// Configs for TilePartitioner Child structs
struct StreamKTilePartitionerV2PersistentExpected
{