mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
[CK_TILE] Stream-K Tree Reduction and Cache Skipping Integration (#3371)
* CK Tile Stream-K Tree Reduction
This change adds the first implementation of the Stream-K tree reduction
strategy into CK Tile. The tree reduction reduces the the number of
steps for accumulating results for a tile from O(N) to O(logN) where N
is the number of workgroups contributing to a C tile.
Additionally, in the original non-atomic reduction strategy, atomics
were used to set the flags buffer and to read from the flags buffer.
Howeover, through investigation with the tree reduciton, atomics with
default (relaxed) semantics were not enough to guarantee workgroups
would not read stale data, leading to incorrect results. Stronger
acquire/release memory orderings are too expensive. So, this change
also eliminates the use of atomics for setting the flags. Instead, we
leverage cache modifiers (e.g., GLC) to avoid writing to cache, thereby
avoiding the use of atomics.
Prelimiary tests were also added for the normal reduction and tree
reduction. More will be added in a future PR via tile engine.
* Move Stream-K kernel files to a subdirectory
* Cleanup Code Style & Handle Unsupported Reductions
This change makes the following small changes:
- Add an explicit else block for unimplemented reduction strategies
- Clarify type of sk_flags_ptr via auto*
- Add description for extra_iters_before_me variable
* Run new copyright script on new files
[ROCm/composable_kernel commit: 22b945e06e]
This commit is contained in:
@@ -23,6 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
|
||||
#TODO: support all arches
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_reduction
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
|
||||
test_gemm_streamk_util.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16Reduction : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16Reduction
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16Reduction, KernelTypesStreamKFp16Reduction);
|
||||
|
||||
#include "test_gemm_streamk_reduction_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile_Tree)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu;
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu;
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Tree)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 4;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 4;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles_Tree)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 3;
|
||||
ck_tile::index_t N = N_Tile * 7;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 3;
|
||||
ck_tile::index_t N = N_Tile * 7;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
|
||||
}
|
||||
@@ -33,6 +33,14 @@ using KernelTypesStreamKFp16Persistent = ::testing::Types<
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16Reduction = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>>;
|
||||
|
||||
using KernelTypesStreamKBf16Persistent = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
|
||||
@@ -144,7 +144,11 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
workspace_data.SetZero();
|
||||
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -184,11 +188,6 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
throw std::runtime_error("Reduction Strategy is current unsupported!\n");
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
@@ -252,9 +251,25 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
stride_B,
|
||||
stride_C};
|
||||
|
||||
ck_tile::index_t num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
ck_tile::index_t num_accumulations_per_tile;
|
||||
|
||||
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
num_accumulations_per_tile = invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
}
|
||||
else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Reduction>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
}
|
||||
else
|
||||
{
|
||||
num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::TreeReduction>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
}
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
|
||||
@@ -372,6 +372,85 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings)
|
||||
}
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, SKOnlyLargeK)
|
||||
{
|
||||
/*
|
||||
The StreamKTilePartitionerBaseConfigSKOnlyLargeK has the following form:
|
||||
- tiles in the C tensor: 2
|
||||
- iters_per_tile: 5
|
||||
- grid: 5
|
||||
- dp_tiles: 0
|
||||
- sk_tiles: 2
|
||||
- iters_per_sk_cta: 2
|
||||
- extra_iters: 0
|
||||
|
||||
The tiles with iters are as follows:
|
||||
|
||||
tile_idx: __________0_________|_________1_________|
|
||||
tile_iter:| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|
||||
| | | | | | | | | | |
|
||||
<---------------SK Tiles--------------->|
|
||||
|
||||
From the above configuration, we get the following:
|
||||
- SK CTA 0: tile_iter_start is 0 with local CTA index of 0 in tile 0
|
||||
- SK CTA 1: tile_iter_start is 0 with local CTA index of 1 in tile 0
|
||||
- SK CTA 2: tile_iter_start is 0 with local CTA index of 2 in tile 0
|
||||
- SK CTA 2: tile_iter_start is 5 with local CTA index of 0 in tile 1
|
||||
- SK CTA 3: tile_iter_start is 5 with local CTA index of 1 in tile 1
|
||||
- SK CTA 4: tile_iter_start is 5 with local CTA index of 2 in tile 1
|
||||
*/
|
||||
|
||||
// Now we create a vector of triplets (tile_iter_start, cta_idx, tile_local_cta_idx) to test
|
||||
std::vector<std::array<ck_tile::index_t, 3>> sk_triplets{
|
||||
{0, 0, 0}, {0, 1, 1}, {0, 2, 2}, {5, 2, 0}, {5, 3, 1}, {5, 4, 2}};
|
||||
|
||||
for(const auto& triplet : sk_triplets)
|
||||
{
|
||||
const auto& [tile_iter_start, cta_idx, tile_local_cta_idx] = triplet;
|
||||
test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigSKOnlyLargeK>(
|
||||
tile_iter_start, cta_idx, tile_local_cta_idx);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, DP2TileSK)
|
||||
{
|
||||
/*
|
||||
The StreamKTilePartitionerBaseConfigDP2TileSK has the following form:
|
||||
- tiles in the C tensor: 7
|
||||
- iters_per_tile: 3
|
||||
- grid: 3
|
||||
- dp_tiles: 3
|
||||
- sk_tiles: 4
|
||||
- iters_per_sk_cta: 2
|
||||
- extra_iters: 2
|
||||
|
||||
The tiles with iters are as follows:
|
||||
|
||||
tile_idx: ____0___|___1___|___2___|___3___|___4___|____5____|____6____|
|
||||
tile_iter:| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
|
||||
| | | | | | | | | | | | | | |
|
||||
|<-------DP Tiles------>|<------------SK Tiles------------->|
|
||||
|
||||
From the above configuration, we get the following:
|
||||
- SK CTA 0: tile_iter_start is 6 with local CTA index of 0 in tile 3
|
||||
- SK CTA 0: tile_iter_start is 8 with local CTA index of 0 in tile 4
|
||||
- SK CTA 1: tile_iter_start is 8 with local CTA index of 1 in tile 4
|
||||
- SK CTA 1: tile_iter_start is 10 with local CTA index of 0 in tile 5
|
||||
- SK CTA 2: tile_iter_start is 12 with local CTA index of 0 in tile 6
|
||||
*/
|
||||
|
||||
// Now we create a vector of triplets (tile_iter_start, cta_idx, tile_local_cta_idx) to test
|
||||
std::vector<std::array<ck_tile::index_t, 3>> sk_triplets{
|
||||
{6, 0, 0}, {8, 0, 0}, {8, 1, 1}, {10, 1, 0}, {12, 2, 0}};
|
||||
|
||||
for(const auto& triplet : sk_triplets)
|
||||
{
|
||||
const auto& [tile_iter_start, cta_idx, tile_local_cta_idx] = triplet;
|
||||
test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigDP2TileSK>(
|
||||
tile_iter_start, cta_idx, tile_local_cta_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Persistent
|
||||
TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly)
|
||||
{
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include <array>
|
||||
|
||||
enum StreamKTilePartitionerBaseMethodId
|
||||
{
|
||||
@@ -12,7 +13,8 @@ enum StreamKTilePartitionerBaseMethodId
|
||||
GET_TILE_BOUNDARIES,
|
||||
GET_TILE_INDEX,
|
||||
GET_ITER_BOUNDARIES,
|
||||
GET_OUTPUT_TILE_INDEX
|
||||
GET_OUTPUT_TILE_INDEX,
|
||||
GET_TILE_LOCAL_CTA_INDEX
|
||||
};
|
||||
|
||||
// Base kernel wrapper class to facilitate testing class device functions.
|
||||
@@ -136,6 +138,22 @@ struct KernelWrapperSpecialized<TilePartitioner,
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>
|
||||
: public KernelWrapper<TilePartitioner>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<TilePartitioner>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
|
||||
{
|
||||
ck_tile::index_t tile_local_cta_index =
|
||||
kargs.tile_partitioner.get_tile_local_cta_index(kargs.arg1, kargs.arg2);
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) = tile_local_cta_index;
|
||||
}
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseExpected
|
||||
{
|
||||
ck_tile::index_t sk_tiles_;
|
||||
@@ -243,6 +261,22 @@ struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBas
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 8;
|
||||
static constexpr ck_tile::index_t N = 2;
|
||||
static constexpr ck_tile::index_t K = 10;
|
||||
static constexpr ck_tile::index_t GRID = 5;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 2;
|
||||
static constexpr ck_tile::index_t K_TILE = 2;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
|
||||
@@ -314,6 +348,38 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx,
|
||||
EXPECT_EQ(in, in_expected);
|
||||
};
|
||||
|
||||
template <typename Config>
|
||||
void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start,
|
||||
ck_tile::index_t cta_idx,
|
||||
ck_tile::index_t expected_tile_local_cta_idx)
|
||||
{
|
||||
// Types
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<typename Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<typename Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem tile_local_cta_idx_dev(sizeof(ck_tile::index_t));
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(tile_iter_start,
|
||||
cta_idx,
|
||||
Config::UNUSED,
|
||||
tile_local_cta_idx_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t tile_local_cta_idx;
|
||||
tile_local_cta_idx_dev.FromDevice(&tile_local_cta_idx);
|
||||
EXPECT_EQ(tile_local_cta_idx, expected_tile_local_cta_idx);
|
||||
}
|
||||
|
||||
// Configs for TilePartitioner Child structs
|
||||
struct StreamKTilePartitionerV2PersistentExpected
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user