This commit is contained in:
Mateusz Ozga
2025-05-13 10:50:25 +00:00
parent 05b46cd8e5
commit e17ac63e4a
6 changed files with 137 additions and 145 deletions

View File

@@ -1496,8 +1496,10 @@ struct ElementWiseAdd
* @note [return] Perform element-wise addition and store the result in 'r'
*/
template <typename ResT, typename ParamT>
CK_TILE_DEVICE auto operator()(ResT& r, const ParamT& a, const ParamT& b, const ParamT& c) const
-> void
CK_TILE_DEVICE auto operator()(ResT& r,
[[maybe_unused]] const ParamT& a,
[[maybe_unused]] const ParamT& b,
[[maybe_unused]] const ParamT& c) const -> void
{
r = a + b + c;
}
@@ -1536,7 +1538,7 @@ struct ElementWiseMul
CK_TILE_DEVICE auto operator()(ResT& r, const ParamT& a, const ParamT& b, const ParamT& c) const
-> void
{
r = a + b + c;
r = a * b * c;
}
/**

View File

@@ -154,7 +154,7 @@ struct CShuffleEpilogue
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
onst DsDramWindows& ds_dram_window,
const DsDramWindows& ds_dram_window,
void* p_smem)
{
@@ -190,10 +190,6 @@ struct CShuffleEpilogue
[&](auto idx) { return make_tile_window(ds_dram_window[idx], dram_tile_distribution); },
number<NumDTensor>{});
using elemenet_wise_output_t =
decltype(load_tile(make_tile_window(out_lds_window, dram_tile_distribution)));
elemenet_wise_output_t elemenet_wise_output;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
@@ -215,26 +211,26 @@ struct CShuffleEpilogue
store_tile(in_lds_window, c_warp_in_tensor_casted);
block_sync_lds();
const auto c_out_tensor =
auto c_out_tensor =
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
const auto c_ds_tiles = concat_tuple_of_reference(
tie(elemenet_wise_output, c_out_tensor),
tie(c_out_tensor, c_out_tensor),
generate_tie(
[&](auto i) -> const auto& { return ds_tensor[i]; }, number<NumDTensor>{}));
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
tile_elementwise_in_out_unpack_tuple(typename Problem::CDElementwise{}, c_ds_tiles);
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
store_tile(out_dram_window, elemenet_wise_output);
}
else
{
update_tile(out_dram_window, c_out_tensor);
update_tile(out_dram_window, elemenet_wise_output);
}
if constexpr(iAccess != num_access - 1)
{

View File

@@ -9,16 +9,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/core/utility/env.hpp"
namespace ck_tile {
/// @brief The GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments
/// object. It contain all necessary information required to build proper kernel argument
/// and launch kernel on GPU.
template <index_t NumDTensor = 0>
struct GemmHostArgs
{
@@ -64,75 +57,23 @@ struct GemmHostArgs
index_t k_batch;
};
/// @brief The GEMM kernel device arguments.
template <typename DType = tuple<>>
struct GemmKernelArgs
{
/// @brief The A input tensor's pointer to device memory.
const void* a_ptr;
/// @brief The B input tensor's pointer to device memory.
const void* b_ptr;
/// @brief The Ds input tensor's tuple to device memory.
const DType ds_ptr;
/// @brief The C output tensor's pointer to device memory.
void* c_ptr;
/// @brief GEMM's M dimension size.
index_t M;
/// @brief GEMM's N dimension size.
index_t N;
/// @brief GEMM's K dimension size.
index_t K;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of A tensor.
index_t stride_A;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of B tensor.
index_t stride_B;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of Ds tensor.
const index_t* stride_Ds;
/// @brief The distance between consecutive elements of non-contiguous dimension
/// (in memory) of C tensor.
index_t stride_C;
index_t k_batch;
};
/// @brief The GEMM kernel template.
///
/// @paragraph Overview Overview
/// This class provides the generic matrix multiplication kernel template. By semantic
/// division of GEMM algorithm into following parts we achieve flexible, versatile
/// and robust kernel implementation.
///
/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator()
/// function call operator" which determines the work scope of each workgroup.
/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm.
/// This is the place where each workgroup is loading data from global memory and
/// carrying out dot products.
/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation
/// responsible for storing results to global memory. This is also the place where
/// any additional operator fusion may take place.
///
/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_
/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all
/// internal details of those functional parts. You can think of it like both gemm and
/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover
/// the policy is responsible for definition of all necessary data layouts and thread's
/// work distribution.
///
/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into the
/// output data tile to be calculated. It determines the workgroup to
/// data relationship (or in other words - which data would be
/// processed and calculated by which workgroup).
/// @tparam GemmPipeline_ The type of class which provides the core part of matrix
/// multiplication. This class should provide implementation of data
/// loading from global memory and performing block-wise matrix
/// multiplication. You can think of it as a work done by single
/// workgroup point of view.
/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix
/// multiplication implementation. It is responsible for storing
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
/// the output C tensor in global memory.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
@@ -580,10 +521,9 @@ struct GemmKernel
}
}();
return make_tuple(a_tensor_view,
b_tensor_view,
generate_tuple(d_tensor_view, number<NumDTensor>{}),
c_tensor_view);
const auto& ds_tensor_view = generate_tuple(d_tensor_view, number<NumDTensor>{});
return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
}
template <typename TensorView>
@@ -740,7 +680,9 @@ struct GemmKernel
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
* @tparam DstInMemOp Destination memory operation (default: set).
*/
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
const BDataType* b_ptr,
const DsGridPointer ds_ptr,
@@ -752,10 +694,8 @@ struct GemmKernel
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr,c_ptr, kargs, splitk_batch_offset);
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -774,9 +714,12 @@ struct GemmKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}
.template operator()<decltype(c_block_window),
decltype(c_block_tile),
decltype(d_block_window),
DstInMemOp>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
/**
@@ -795,7 +738,9 @@ struct GemmKernel
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
* @tparam DstInMemOp Destination memory operation (default: set).
*/
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
const BDataType* b_ptr,
const DsGridPointer ds_ptr,
@@ -808,10 +753,8 @@ struct GemmKernel
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
@@ -829,10 +772,12 @@ struct GemmKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}
.template operator()<decltype(c_block_window),
decltype(c_block_tile),
decltype(d_block_window),
DstInMemOp>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
CK_TILE_DEVICE void operator()(GemmKernelArgs<DsGridPointer>& kargs) const
@@ -849,7 +794,6 @@ struct GemmKernel
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
// const DsGridPointer* ds_ptr = reinterpret_cast<const DsGridPointer*>(kargs.ds_ptr);
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
@@ -859,9 +803,7 @@ struct GemmKernel
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
if(kargs.k_batch == 1)
{
RunGemm2LDS(a_ptr,
b_ptr,
@@ -874,12 +816,27 @@ struct GemmKernel
i_m,
i_n);
}
else
{
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS<memory_operation_enum::atomic_add>(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
}
else
{
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
if(kargs.k_batch == 1)
{
RunGemm(a_ptr,
b_ptr,
@@ -891,6 +848,22 @@ struct GemmKernel
i_m,
i_n);
}
else
{
if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm<memory_operation_enum::atomic_add>(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
}
}
}
};

View File

@@ -2,8 +2,8 @@
TYPED_TEST(TestCkTileMultipleDGemm, Basic)
{
constexpr int M = 256;
constexpr int N = 512;
constexpr int K = 512;
constexpr int M = 3840;
constexpr int N = 4096;
constexpr int K = 4096;
this->Run(M, N, K);
}

View File

@@ -120,10 +120,13 @@ class TestCkTileMultipleDGemm : public ::testing::Test
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
@@ -152,7 +155,8 @@ class TestCkTileMultipleDGemm : public ::testing::Test
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC>>;
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -180,11 +184,29 @@ class TestCkTileMultipleDGemm : public ::testing::Test
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
if(has_hot_loop)
{
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else

View File

@@ -175,52 +175,51 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdWeight3d, KernelTypes3d);
TYPED_TEST(TestGroupedConvndBwdWeight1d, Test1D)
{
// this->conv_params.clear();
// this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
// this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
// this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
// this->conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}});
// this->conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}});
// this->conv_params.push_back({1, 1, 1, 1, 1, {3}, {32}, {1}, {1}, {1}, {1}});
// this->Run();
this->conv_params.clear();
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
this->conv_params.push_back({1, 2, 32, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 2, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
this->conv_params.push_back({1, 1, 1, 1, 32, {3}, {32}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 1, 1, 64, 3, {3}, {32}, {1}, {1}, {1}, {1}});
this->conv_params.push_back({1, 1, 1, 1, 1, {3}, {32}, {1}, {1}, {1}, {1}});
this->Run();
}
TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
{
this->conv_params.clear();
this->conv_params.push_back(
{2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
// this->conv_params.push_back(
// {2, 2, 64, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
// this->conv_params.push_back({2, 2, 64, 3, 3, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0,
// 0}}); this->conv_params.push_back({2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0},
// {0, 0}}); this->conv_params.push_back(
// {2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
// this->conv_params.push_back(
// {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
// this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1,
// 1}}); this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1},
// {1, 1}}); this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1,
// 1}, {1, 1}}); this->conv_params.push_back(
// {2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 2, 64, 5, 5, {1, 1}, {7, 7}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back(
{2, 2, 4, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->conv_params.push_back(
{2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
this->Run();
}
TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
{
// this->conv_params.clear();
// this->conv_params.push_back(
// {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
// this->conv_params.push_back(
// {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
// this->conv_params.push_back(
// {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
// this->conv_params.push_back(
// {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
// this->conv_params.push_back(
// {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
// this->conv_params.push_back(
// {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
// this->conv_params.push_back(
// {3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
// this->Run();
}
this->conv_params.clear();
this->conv_params.push_back(
{3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
this->conv_params.push_back(
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->conv_params.push_back(
{3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
this->Run();
}