[CK TILE] Implement cschuflle algorithm (#1842)

* [CK TILE] Implement cschuflle algorithm

* Rebase

* Vector store size fixes

* fixes

* Fixes

* fixes

* fmha fix

* fixes

* fixes of fixes
This commit is contained in:
Bartłomiej Kocot
2025-01-30 11:57:39 +01:00
committed by GitHub
parent c5fff071e5
commit 25e2e0f04a
18 changed files with 408 additions and 371 deletions

View File

@@ -1,9 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
@@ -23,6 +25,26 @@ struct Default2DEpilogueProblem
static constexpr bool UseRawStore = UseRawStore_;
};
template <typename AccDataType_,
typename ODataType_,
typename CLayout_,
bool kPadM_,
bool kPadN_,
index_t kMPerXdl_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_,
bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
{
using CLayout = remove_cvref_t<CLayout_>;
static constexpr index_t kMPerXdl = kMPerXdl_;
static constexpr index_t kNPerXdl = kNPerXdl_;
static constexpr index_t kKPerXdl = kKPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
};
template <typename Problem_, typename Policy_ = void>
struct Default2DEpilogue
{
@@ -35,14 +57,13 @@ struct Default2DEpilogue
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template <typename ODramWindowTmp,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
{
// TODO: this is ugly
@@ -71,4 +92,76 @@ struct Default2DEpilogue
}
}
};
template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
using Problem = remove_cvref_t<Problem_>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
using WG = WarpGemmMfmaDispatcher<ODataType,
ODataType,
AccDataType,
kMPerXdl,
kNPerXdl,
kKPerXdl,
isCTransposed>;
using CWarpDstr = typename WG::CWarpDstr;
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(isCTransposed)
{
// In this case each thread has just a single item in Mdim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
};
} // namespace ck_tile