[CK_TILE] Multiple-D GEMM example (#2219)

* Multiple d, initial commit

* Check Ds Layout

* Readme and clang format

* Update branch & conflicts

* Multiple D - fix clang-formatter

* Rename elemetwise_op

* Fix CI

* Code review part1

* Remove printf

* Remove unnecessary comment

* Add new tests with Col layout

* Review part 2

* Added support for Multiple D GEMM

* Update comment

* Remove maybe_unused

* Clang-format

* Review part 3

* Add comment to function

* Add comment to function: another

* Take number of params for a refrence function

* Remove additional d param for 0 tensor

* Change name of function

* Fix CI fails
This commit is contained in:
Mateusz Ozga
2025-06-13 19:39:11 +02:00
committed by GitHub
parent 3a0cb27966
commit bd96ac9742
34 changed files with 2267 additions and 285 deletions

View File

@@ -11,9 +11,12 @@ namespace ck_tile {
template <typename ADataType_,
typename BDataType_,
typename DsDataType_,
typename AccDataType_,
typename ODataType_,
typename CLayout_,
typename DsLayout_,
typename ELayout_,
typename CDElementwise_,
index_t kBlockSize_,
index_t kM_,
index_t kN_,
@@ -31,7 +34,10 @@ struct CShuffleEpilogueProblem
using BDataType = remove_cvref_t<BDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using CLayout = remove_cvref_t<CLayout_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
@@ -43,6 +49,10 @@ struct CShuffleEpilogueProblem
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
};
template <typename Problem_, typename Policy_ = void>
@@ -53,10 +63,13 @@ struct CShuffleEpilogue
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
@@ -69,7 +82,10 @@ struct CShuffleEpilogue
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
/**
* @brief Get the vector store size for C tensor.
*
@@ -83,22 +99,49 @@ struct CShuffleEpilogue
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeC()
{
constexpr index_t max_vector_size = 16;
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(ODataType)));
}
else
{
static_assert(false, "Unsupported CLayout!");
static_assert(false, "Unsupported ELayout!");
}
}
/**
* @brief Get the vector store size for Di tensor.
*
* @return The vector store size for Di tensor.
*/
template <index_t I>
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeD(number<I> index)
{
constexpr index_t max_vector_size = 16;
using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return std::min(static_cast<int>(NPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
{
return std::min(static_cast<int>(MPerIteration),
static_cast<int>(max_vector_size / sizeof(DiDataType)));
}
else
{
static_assert(false, "Unsupported DLayout!");
}
return max_vector_size / sizeof(DiDataType);
}
/**
* @brief Shuffle tile configuration parameters
*
@@ -116,7 +159,7 @@ struct CShuffleEpilogue
else
{
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
(kMPerBlock % num_xdl_shuffles == 0),
@@ -147,7 +190,8 @@ struct CShuffleEpilogue
}();
static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
using WG = WarpGemmMfmaDispatcher<ADataType,
using WG = WarpGemmMfmaDispatcher<ADataType,
BTypeToUse,
AccDataType,
MPerXdl,
@@ -162,14 +206,14 @@ struct CShuffleEpilogue
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
{
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<NPerIterationShuffle>{}, number<1>{}));
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
{
return make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
@@ -177,7 +221,7 @@ struct CShuffleEpilogue
}
else
{
static_assert(false, "Unsupported CLayout!");
static_assert(false, "Unsupported ELayout!");
}
}
@@ -202,9 +246,11 @@ struct CShuffleEpilogue
return MPerIterationShuffle * NPerIterationShuffle * sizeof(ODataType);
}
template <typename ODramWindow, typename OAccTile>
CK_TILE_DEVICE auto
operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
const OAccTile& o_acc_tile,
const DsDramWindows& ds_dram_windows,
void* p_smem)
{
constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
@@ -230,7 +276,7 @@ struct CShuffleEpilogue
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
static_assert(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>,
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
"Currently, the CShuffle Epilogue only supports the Row Major Output layout");
using TileEncodingPattern =
@@ -242,6 +288,12 @@ struct CShuffleEpilogue
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
auto d_dram_windows = generate_tuple(
[&](auto idx) {
return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
},
number<NumDTensor>{});
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
@@ -265,8 +317,17 @@ struct CShuffleEpilogue
store_tile(in_lds_window, c_warptile_in_tensor_casted);
block_sync_lds();
const auto c_out_tensor =
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
const auto ds_tensor = generate_tuple(
[&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
const auto c_ds_tiles = concat_tuple_of_reference(
tie(c_out_tensor, c_out_tensor),
generate_tie(
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
if constexpr(MemoryOperation == memory_operation_enum::set)
{
@@ -279,7 +340,13 @@ struct CShuffleEpilogue
if constexpr(iAccess != num_access - 1)
{
constexpr auto step = SFC::get_forward_step(iAccess);
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
static_for<0, NumDTensor, 1>{}([&](auto idx) {
move_tile_window(d_dram_windows[idx],
{step.at(number<0>{}), step.at(number<1>{})});
});
}
});
}