mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Resolve conflict by accepting toy branch version
This commit is contained in:
@@ -21,7 +21,11 @@ struct BlockGemmASmemBSmemCReg
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using WarpGemm = remove_cvref_t<
|
||||
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template at<0>())>;
|
||||
decltype(Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<0>())>;
|
||||
static constexpr index_t MWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<1>();
|
||||
static constexpr index_t NWarp =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem>().template get<2>();
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
@@ -42,15 +46,11 @@ struct BlockGemmASmemBSmemCReg
|
||||
static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
// A block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
constexpr index_t MIterPerWarp = BlockGemmShape::kM / (MWarp * WarpGemm::kM);
|
||||
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
@@ -70,11 +70,7 @@ struct BlockGemmASmemBSmemCReg
|
||||
// B block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
constexpr index_t NIterPerWarp = BlockGemmShape::kN / (NWarp * WarpGemm::kN);
|
||||
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
@@ -99,24 +95,24 @@ struct BlockGemmASmemBSmemCReg
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_warp_tile_;
|
||||
ALdsTile b_warp_tile_;
|
||||
ALdsTile aWarpTile;
|
||||
BLdsTile bWarpTile;
|
||||
|
||||
// Prefetch from LDS to warp register
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
|
||||
const BSmemBlockWindow& b_block_window)
|
||||
{
|
||||
load_tile(a_warp_tile_, a_block_window);
|
||||
load_tile(b_warp_tile_, b_block_window);
|
||||
aWarpTile = load_tile(a_block_window);
|
||||
bWarpTile = load_tile(b_block_window);
|
||||
}
|
||||
#endif
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockWindowTmp& a_block_window_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
|
||||
@@ -131,17 +127,11 @@ struct BlockGemmASmemBSmemCReg
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
#if !defined(ENABLE_PREFETCH)
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
@@ -149,13 +139,13 @@ struct BlockGemmASmemBSmemCReg
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// construct A-warp-window
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WG::kM,
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
@@ -165,19 +155,18 @@ struct BlockGemmASmemBSmemCReg
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// construct B-warp-window
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN,
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
@@ -187,48 +176,46 @@ struct BlockGemmASmemBSmemCReg
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
#pragma message("local data share prefetch")
|
||||
// read A warp tensor from A block tensor
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
#pragma message("local data share prefetch")
|
||||
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
#else
|
||||
// read A warp tensor from A block window
|
||||
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
#endif
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
// read B warp tensor from B block tensor
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
#else
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
#endif
|
||||
// read C warp tensor from C block tensor
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
// Warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
@@ -240,8 +227,8 @@ struct BlockGemmASmemBSmemCReg
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
|
||||
@@ -255,17 +242,11 @@ struct BlockGemmASmemBSmemCReg
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get(number<0>{}))>;
|
||||
|
||||
constexpr index_t MWarp = config.template get(number<1>{});
|
||||
constexpr index_t NWarp = config.template get(number<2>{});
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
#if !defined(ENABLE_PREFETCH)
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
@@ -273,13 +254,13 @@ struct BlockGemmASmemBSmemCReg
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// construct A-warp-window
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WG::kM,
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
{a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM,
|
||||
a_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
|
||||
make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_tmp), KIterPerWarp>,
|
||||
@@ -289,19 +270,18 @@ struct BlockGemmASmemBSmemCReg
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// construct B-warp-window
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WG::kN,
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
{b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN,
|
||||
b_block_window_tmp.get_window_origin().at(number<1>{})},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{}));
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
@@ -311,13 +291,13 @@ struct BlockGemmASmemBSmemCReg
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
static_assert(std::is_same_v<CDataType, typename WG::CDataType>, "wrong!");
|
||||
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
|
||||
|
||||
// Construct C-Block-Tensor
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
@@ -329,44 +309,42 @@ struct BlockGemmASmemBSmemCReg
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
// hot loop:
|
||||
// Hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
// read A warp tensor from A block tensor
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_warp_tile_.get_y_sliced_thread_data(
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
#else
|
||||
// read A warp tensor from A block window
|
||||
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
#endif
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
// read B warp tensor from B block tensor
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_warp_tile_.get_y_sliced_thread_data(
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
#else
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
#endif
|
||||
// read C warp tensor from C block tensor
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// warp GEMM
|
||||
// Warp GEMM
|
||||
if constexpr(KIterPerWarp == 0)
|
||||
{
|
||||
// c = a * b
|
||||
c_warp_tensor = WG{}(a_warp_tensor, b_warp_tensor);
|
||||
c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -375,10 +353,10 @@ struct BlockGemmASmemBSmemCReg
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
|
||||
@@ -17,19 +17,29 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
#if defined(ADJUST_BLOCK_TILE_SHAPE)
|
||||
constexpr index_t kMWarp = 2;
|
||||
constexpr index_t kNWarp = 2;
|
||||
#else
|
||||
constexpr index_t kMWarp = 4;
|
||||
constexpr index_t kNWarp = 1;
|
||||
#endif
|
||||
|
||||
#if defined(NAIVE_IMPLEMENTATION)
|
||||
#pragma message("mfma m32 n32 k8")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
#elif defined(USING_MFMA_32x32x_8x2)
|
||||
#pragma message("mfma m32 n32 k16")
|
||||
@@ -37,13 +47,15 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
#elif defined(USING_MFMA_16x16x16)
|
||||
#pragma message("mfma m16 n16 k16")
|
||||
@@ -51,13 +63,15 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
#elif defined(USING_MFMA_16x16x_16x2)
|
||||
#pragma message("mfma m16 n16 k32")
|
||||
@@ -65,13 +79,15 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(
|
||||
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, 2, 2);
|
||||
return make_tuple(
|
||||
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, kMWarp, kNWarp);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
|
||||
@@ -42,23 +42,8 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
}
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
static constexpr index_t APackedSize =
|
||||
static constexpr index_t kPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
@@ -74,35 +59,35 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
constexpr index_t WaveNumM = BlockGemm::MWarp;
|
||||
constexpr index_t WaveNumN = BlockGemm::NWarp;
|
||||
|
||||
constexpr index_t AB_LDS_RW_Width = GetSmemPack();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
kMPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
kNPerBlock * kKPerBlock / (kBlockSize * GetVectorSizeB());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
|
||||
kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
|
||||
kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
|
||||
WaveNumN * kMPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * AB_LDS_RW_Width);
|
||||
WaveNumM * kNPerBlock * kKPerBlock / (kBlockSize * AB_LDS_RW_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
constexpr index_t C_MFMA_Inst_Num = kMPerBlock * kNPerBlock * kKPerBlock /
|
||||
(kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a = AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16
|
||||
constexpr auto num_ds_read_inst_a = AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16
|
||||
? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b = AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16
|
||||
constexpr auto num_ds_read_inst_b = AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
@@ -116,9 +101,9 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
AB_LDS_RW_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
|
||||
AB_LDS_RW_Width * sizeof(ADataType) / kPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
AB_LDS_RW_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
|
||||
AB_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
@@ -275,18 +260,18 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{0, 0},
|
||||
b_copy_dram_window.get_tile_distribution());
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block,
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()));
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block,
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()));
|
||||
#else
|
||||
@@ -313,59 +298,63 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
|
||||
ABlockTile a_block_tile;
|
||||
BBlockTile b_block_tile;
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
|
||||
// -------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock);
|
||||
|
||||
// Prefetch
|
||||
// Global read 0
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
#pragma message("global prefetch")
|
||||
// Prefetch
|
||||
// Global read 0
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
block_sync_lds();
|
||||
if(num_loop > 1)
|
||||
{
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Prefetch from LDS to warp register in block gemm
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
// LDS write 0
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// Prefetch from LDS to warp register in block gemm
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Main body
|
||||
if constexpr(HasHotLoop)
|
||||
if(num_loop > 2)
|
||||
{
|
||||
index_t i = 0;
|
||||
index_t iCounter = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 0
|
||||
// LDS write 1
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 0
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
// Global read 2
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
@@ -375,116 +364,37 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
// Prefetch from LDS to warp register in block gemm
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
HotLoopScheduler();
|
||||
#endif
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 2));
|
||||
iCounter += 1;
|
||||
} while(iCounter < (num_loop - 2));
|
||||
}
|
||||
|
||||
// Tail
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
if(num_loop > 1)
|
||||
{
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
#elif defined(ENABLE_PREFETCH)
|
||||
// Prefetch
|
||||
// Global read 0
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
|
||||
{
|
||||
// Move to 1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
// Global read 1
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
|
||||
// LDS write 0
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
// Global read 1
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 2;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// Move to i + 2
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
// Global read i + 2
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
|
||||
// LDS write i + 1
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
// Global read i + 2
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
|
||||
iCounter--;
|
||||
|
||||
} while(iCounter > 0);
|
||||
|
||||
// Tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 2
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
#else
|
||||
// non-prefetch
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
index_t iCounter = num_loop;
|
||||
|
||||
while(iCounter > 0)
|
||||
{
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
load_tile(a_block_tile, a_copy_dram_window);
|
||||
load_tile(b_block_tile, b_copy_dram_window);
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
store_tile(a_copy_lds_window, a_block_tile);
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
|
||||
@@ -313,26 +313,18 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
|
||||
}
|
||||
|
||||
|
||||
@@ -29,28 +29,6 @@ struct GridGemmProblem
|
||||
using CElementFunction = CElementFunction_;
|
||||
};
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
template <typename BlockTile_,
|
||||
typename BlockWarps_,
|
||||
typename WarpTile_,
|
||||
bool PermuteA_ = false,
|
||||
bool PermuteB_ = false>
|
||||
struct TileGemmShape
|
||||
{
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
using BlockWarps = remove_cvref_t<BlockWarps_>;
|
||||
using WarpTile = remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
|
||||
static constexpr index_t kM = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kN = BlockTile::at(number<1>{});
|
||||
static constexpr index_t kK = BlockTile::at(number<2>{});
|
||||
|
||||
static constexpr bool PermuteA = PermuteA_;
|
||||
static constexpr bool PermuteB = PermuteB_;
|
||||
};
|
||||
#else
|
||||
template <index_t kMPerTile, index_t kNPerTile, index_t kKPerTile>
|
||||
struct TileGemmShape
|
||||
{
|
||||
@@ -58,71 +36,7 @@ struct TileGemmShape
|
||||
static constexpr index_t kN = kNPerTile;
|
||||
static constexpr index_t kK = kKPerTile;
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool DoubleSmemBuffer_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false>
|
||||
struct TileGemmTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct BlockGemmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename Traits::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kPadM = Traits::kPadM;
|
||||
static constexpr bool kPadN = Traits::kPadN;
|
||||
static constexpr bool kPadK = Traits::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
};
|
||||
#else
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
@@ -137,7 +51,6 @@ struct BlockGemmPipelineProblem
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
};
|
||||
#endif
|
||||
|
||||
// C = A * B
|
||||
template <typename ADataType,
|
||||
@@ -234,60 +147,15 @@ struct Gemm
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
static constexpr index_t M_Warp = 4;
|
||||
static constexpr index_t N_Warp = 1;
|
||||
static constexpr index_t K_Warp = 1;
|
||||
static constexpr index_t M_Warp_Tile = 16;
|
||||
static constexpr index_t N_Warp_Tile = 16;
|
||||
static constexpr index_t K_Warp_Tile = 32;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
static constexpr bool TransposeC = false;
|
||||
#endif
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline()
|
||||
{
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
// Block GEMM pipeline w/ instruction scheduling
|
||||
using GemmShape = TileGemmShape<sequence<kMPerBlock, kNPerBlock, kKPerBlock>,
|
||||
sequence<M_Warp, N_Warp, K_Warp>,
|
||||
sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>,
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
|
||||
using GemmTraits = TileGemmTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
/* ALayout */ tensor_layout::gemm::RowMajor,
|
||||
/* BLayout */ tensor_layout::gemm::ColumnMajor,
|
||||
/* CLayout */ tensor_layout::gemm::RowMajor,
|
||||
TransposeC>;
|
||||
|
||||
using BlockGemmPipelineProblem_ =
|
||||
BlockGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmTraits,
|
||||
GemmPipelineScheduler::Intrawave,
|
||||
/* Has hot loop */ true,
|
||||
TailNumber::Full>;
|
||||
#else
|
||||
using BlockGemmPipelineProblem_ =
|
||||
BlockGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
kBlockSize,
|
||||
TileGemmShape<kMPerBlock, kNPerBlock, kKPerBlock>>;
|
||||
|
||||
#endif
|
||||
return BlockGemmPipelineAGmemBGmemCReg<BlockGemmPipelineProblem_>{};
|
||||
}
|
||||
};
|
||||
|
||||
6
example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt
Executable file → Normal file
6
example/ck_tile/99_toy_example/03_flash_attention_fwd/CMakeLists.txt
Executable file → Normal file
@@ -10,6 +10,12 @@ set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template so the source compiles without explicitly declaring function specializations
|
||||
list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
|
||||
option(ENABLE_TOY_FA_FWD_OPT "Enable toy FA fwd optimization" OFF)
|
||||
if(ENABLE_TOY_FA_FWD_OPT)
|
||||
message("Compiling with toy FA fwd optimization")
|
||||
target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE TOY_FA_FWD_OPT)
|
||||
endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
|
||||
@@ -26,6 +26,251 @@ struct BlockGemmARegBSmemCRegV1
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
// B block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN);
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
constexpr index_t MPerXDL = WG::kM;
|
||||
constexpr index_t NPerXDL = WG::kN;
|
||||
constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = config.template get<1>();
|
||||
|
||||
constexpr index_t B_LDS_RW_Width = SmemPack;
|
||||
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * VectorSizeB);
|
||||
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
// B split schedule
|
||||
constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
constexpr auto num_mfma_per_dswrite_b =
|
||||
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
|
||||
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_mfma_per_dswrite_b *
|
||||
num_dswrite_per_issue_b,
|
||||
0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
|
||||
ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp>
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BLdsTile& b_block_tensor_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
// construct A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using BWarpDstr = typename WG::BWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using BWarpTensor = typename WG::BWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
@@ -38,6 +283,8 @@ struct BlockGemmARegBSmemCRegV1
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
@@ -46,7 +293,7 @@ struct BlockGemmARegBSmemCRegV1
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
@@ -180,6 +427,8 @@ struct BlockGemmARegBSmemCRegV1
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
@@ -188,7 +437,7 @@ struct BlockGemmARegBSmemCRegV1
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
|
||||
@@ -10,10 +10,25 @@ namespace ck_tile {
|
||||
|
||||
struct BlockGemmARegBSmemCRegV1DefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 128)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported configuration for warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -10,10 +10,25 @@ namespace ck_tile {
|
||||
|
||||
struct BlockGemmARegBSmemCRegV1K8Policy
|
||||
{
|
||||
template <typename Problem>
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr(kM0 == 128)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported configuration for warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -13,16 +13,13 @@ namespace ck_tile {
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, index_t kHeadDim>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
Problem,
|
||||
BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>>
|
||||
template <typename Problem, typename Policy>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
@@ -58,8 +55,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = a_element_func;
|
||||
@@ -135,6 +131,8 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
@@ -159,6 +157,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -218,6 +219,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
ignore = b_element_func;
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
|
||||
// A tile in Reg,blockTensor
|
||||
// This tensor distribution is used to construct both distributed tensors for local buffer store
|
||||
// and read, without buffer address info
|
||||
@@ -257,58 +261,90 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(
|
||||
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
static_for<0, k_loops, 1>{}([&](auto i_k0) {
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using BLdsTile = typename decltype(block_gemm)::BLdsTile;
|
||||
BLdsTile bWarpTile;
|
||||
|
||||
// Global read 0
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
// LDS write 0
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 1
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
block_sync_lds();
|
||||
|
||||
// LDS read 0
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
}
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 1
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 2
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
// LDS read 1
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
@@ -316,13 +352,15 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
bWarpTile);
|
||||
}
|
||||
|
||||
#endif
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
@@ -336,9 +374,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](const ADataType & a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](const BDataType & b) { return b; },
|
||||
a_reg_block_tensor_tmp,
|
||||
p_smem);
|
||||
}
|
||||
@@ -350,7 +388,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
{
|
||||
return operator()(
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](const BDataType & b) { return b; },
|
||||
a_reg_block_tensor_tmp,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
@@ -3,43 +3,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp"
|
||||
#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// NOTE: Assume A is K-Major
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
template <index_t AKDim_>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
{
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
return policy_impl::make_a_reg_block_descriptor<Problem, BlockGemm>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
return policy_impl::make_b_lds_block_descriptor_3d_pad<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
return policy_impl::make_a_dram_tile_distribution_skip_lds<Problem, BlockGemm>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
return policy_impl::make_b_dram_tile_distribution<Problem>();
|
||||
}
|
||||
static constexpr index_t AKDim = AKDim_;
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto GetBlockGemm()
|
||||
@@ -48,13 +20,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
|
||||
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AKDim_>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
: BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
{
|
||||
static constexpr index_t AKDim = AKDim_;
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
@@ -62,11 +27,13 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = AKDim;
|
||||
|
||||
constexpr auto config =
|
||||
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem, kMPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
@@ -91,6 +58,87 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
// Tile distribution used for the A DRAM tile window.
// It simply forwards to MakeARegBlockDescriptor<Problem>(), so the DRAM load
// distribution and the A register block descriptor are the same object.
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
return MakeARegBlockDescriptor<Problem>();
|
||||
}
|
||||
|
||||
// Builds the LDS block descriptor for the B tile.
// Starts from a naive 3-D descriptor and applies three transforms in order
// (xor -> unmerge -> merge); the returned descriptor has two dimensions whose
// merged lengths are kNPerBlock and kKPerBlock.
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
// Length of the innermost (stride-1) dimension of the naive descriptor below.
constexpr index_t kKPack = 8;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
// 128 / kKPerBlock / sizeof(BDataType), clamped to a minimum of 1.
// NOTE(review): 32 * 4 presumably stands for 32 LDS banks x 4 bytes — confirm.
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
// Naive 3-D descriptor:
//   lengths (kKPerBlock / kKPack * NLdsLayer, kNPerBlock / NLdsLayer, kKPack)
//   strides (kKPack, kKPerBlock * NLdsLayer, 1)
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
// Xor transform over dimensions (1, 0); dimension 2 (kKPack) is passed through.
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
// Unmerge dimension 0 into (NLdsLayer, kKPerBlock / kKPack), placed at
// dimensions 0 and 2 of the 4-D result; the other two dimensions pass through.
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
// Merge dimensions (1, 0) into output dimension 0 and dimensions (2, 3) into
// output dimension 1.
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
// Builds the static tile distribution used for the B DRAM tile.
// The N extent is factored as N0 * N1 * N2 and the K extent as K0 * K1, with
// every factor derived from the block size, the warp size and sizeof(BDataType).
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// K1: number of BDataType elements in 16 bytes.
// NOTE(review): presumably one 16-byte vector access per thread — confirm.
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
// K0: number of K1-sized groups along K.
constexpr index_t K0 = kKPerBlock / K1;
|
||||
// N2: warp size divided by K0.
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
// N1: block size divided by warp size (number of warps in the block).
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
// N0: remaining factor of kNPerBlock after N2 and N1.
// NOTE(review): all of these are integer divisions; nothing here checks that
// they divide evenly — verify against the supported problem shapes.
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace policy_impl {
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_a_lds_block_descriptor_3d_pad()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
transform_tensor_descriptor(a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc =
|
||||
transform_tensor_descriptor(b_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
__host__ __device__ static constexpr auto make_a_reg_block_descriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_a_dram_tile_distribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds()
|
||||
{
|
||||
constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K2 =
|
||||
WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane;
|
||||
// // 16 / sizeof(ADataType);
|
||||
constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t K0 = kKPerBlock / (K1 * K2);
|
||||
|
||||
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t M1 = MWarp;
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_b_dram_tile_distribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto get_block_gemm()
|
||||
{
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegDefaultPolicy;
|
||||
|
||||
return BlockGemmASmemBSmemCReg<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
} // namespace policy_impl
|
||||
} // namespace ck_tile
|
||||
@@ -29,32 +29,28 @@ int main(int argc, char* argv[])
|
||||
using OaccDataType = float;
|
||||
using ODataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t Batch = 64; // Batch Number * Head Number
|
||||
ck_tile::index_t M0 = 4096; // SequenceLengthQ
|
||||
ck_tile::index_t N0 = 4096; // SequencelengthK
|
||||
ck_tile::index_t K0 = 128; // HeadDim
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
[[maybe_unused]] ck_tile::index_t time_kernel = 0;
|
||||
ck_tile::index_t Batch = 64; // Batch Number * Head Number
|
||||
ck_tile::index_t M0 = 4096; // SequenceLengthQ
|
||||
ck_tile::index_t N0 = 4096; // SequencelengthK
|
||||
ck_tile::index_t K0 = 128; // HeadDim
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 3)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
verification = std::stoi(argv[3]);
|
||||
verification = std::stoi(argv[2]);
|
||||
}
|
||||
|
||||
if(argc == 9)
|
||||
else if(argc == 8)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
verification = std::stoi(argv[3]);
|
||||
Batch = std::stoi(argv[4]);
|
||||
M0 = std::stoi(argv[5]);
|
||||
N0 = std::stoi(argv[6]);
|
||||
K0 = std::stoi(argv[7]);
|
||||
N1 = std::stoi(argv[8]);
|
||||
verification = std::stoi(argv[2]);
|
||||
Batch = std::stoi(argv[3]);
|
||||
M0 = std::stoi(argv[4]);
|
||||
N0 = std::stoi(argv[5]);
|
||||
K0 = std::stoi(argv[6]);
|
||||
N1 = std::stoi(argv[7]);
|
||||
}
|
||||
|
||||
std::array<ck_tile::index_t, 3> q_lengths{Batch, M0, K0};
|
||||
|
||||
@@ -8,13 +8,69 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "block_gemm_pipeline_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "flash_attention_fwd_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Returns a lambda that maps a 1-D block id to a 2-D tile index (m, n).
// M0 and N0 are captured by value and act as the extents of the tile grid.
// Ids below update_N0 * update_M0 use the grouped mapping computed first; ids at
// or above that product are re-mapped by the remainder branch at the end.
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
{
|
||||
return [=](index_t block_1d_id) {
|
||||
constexpr index_t M01 = 4;
|
||||
constexpr index_t GroupNum = 8;
|
||||
|
||||
// For non-negative N0 this is N0 rounded down to a multiple of 2 * M01.
const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2;
|
||||
// For non-negative M0 this is M0 rounded down to a multiple of GroupNum * M01.
const auto update_M0 =
|
||||
((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum;
|
||||
|
||||
// Position of the block inside its group of GroupNum consecutive ids.
// NOTE(review): the name suggests a hardware XCD index — confirm.
const auto xcd_id = block_1d_id % GroupNum;
|
||||
|
||||
// Block id with the parity of xcd_id subtracted.
const auto l_block_id = block_1d_id - (xcd_id % 2);
|
||||
|
||||
// NOTE(review): ridn is 0 whenever update_N0 is 0 (N0 < 2 * M01), which makes
// the division below a divide-by-zero; confirm such N0 is never passed or that
// this is acceptable on the target.
const auto ridn = GroupNum * M01 * (update_N0 / 2);
|
||||
const auto rid = (l_block_id - (l_block_id % GroupNum)) / ridn;
|
||||
const auto lu = (l_block_id % GroupNum) + rid * ridn;
|
||||
|
||||
const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01);
|
||||
const auto sub_M0_id = (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum;
|
||||
|
||||
// Grouped mapping: odd xcd_id values are offset by update_N0 / 2 in n, and
// xcd_id / 2 selects an offset of update_M0 / (GroupNum / 2) in m.
auto n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2);
|
||||
auto m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2);
|
||||
|
||||
const auto total_update_size = update_N0 * update_M0;
|
||||
|
||||
// Remainder handling for ids outside the update_M0 x update_N0 region.
if(block_1d_id >= total_update_size)
|
||||
{
|
||||
// 1-based offset of this id past the grouped region.
auto x = (block_1d_id + 1) - total_update_size;
|
||||
// Number of columns left over beyond update_N0.
auto rlen = N0 - update_N0;
|
||||
|
||||
auto rm = 0;
|
||||
auto rn = 0;
|
||||
if(rlen > 0)
|
||||
{
|
||||
rm = (x - 1) / rlen;
|
||||
rn = x % rlen;
|
||||
}
|
||||
|
||||
// First fill the leftover columns (n >= update_N0) for rows 0 .. M0 - 1.
if(rlen > 0 and rm < M0)
|
||||
{
|
||||
n = rn + update_N0;
|
||||
m = rm;
|
||||
}
|
||||
// Otherwise place the id in the rows at and after update_M0.
// NOTE(review): update_N0 is used as a divisor here and may be 0 — confirm.
else
|
||||
{
|
||||
x = x - rlen * M0;
|
||||
rm = (x - 1) / update_N0;
|
||||
rn = x % update_N0;
|
||||
n = rn;
|
||||
m = update_M0 + rm;
|
||||
}
|
||||
}
|
||||
return make_multi_index(m, n);
|
||||
};
|
||||
}
|
||||
|
||||
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
|
||||
// P[M0, N0] = Softmax(S[M0, N0])
|
||||
// O[M0, N1] = P[M0, N0] * V[N1, N0]
|
||||
@@ -53,25 +109,38 @@ struct FlashAttentionFwd
|
||||
const index_t BatchStrideV,
|
||||
const index_t BatchStrideO) const
|
||||
{
|
||||
// divide problem
|
||||
const index_t num_tile_m0 = M0 / kM0PerBlock;
|
||||
const index_t num_tile_n1 = N1 / kN1PerBlock;
|
||||
|
||||
const index_t id_block = get_block_id();
|
||||
|
||||
const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock);
|
||||
const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock);
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
#pragma message("Enable toy FA fwd opt")
|
||||
const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1);
|
||||
|
||||
const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0;
|
||||
const auto id_tile = block2tile(id_block - id_tile_batch * num_tile_n1 * num_tile_m0);
|
||||
|
||||
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
|
||||
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) %
|
||||
num_tile_m0 * kM0PerBlock);
|
||||
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) %
|
||||
num_tile_n1 * kN1PerBlock);
|
||||
|
||||
#else
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
|
||||
return make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [itmp, id_tile_n] = f(id_block, num_tile_n1);
|
||||
const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);
|
||||
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
|
||||
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
|
||||
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);
|
||||
|
||||
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
|
||||
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
|
||||
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);
|
||||
#endif
|
||||
|
||||
const auto kernel_impl = FlashAttentionFwdImpl<QDataType,
|
||||
KDataType,
|
||||
|
||||
@@ -4,17 +4,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "tile_gemm_shape.hpp"
|
||||
#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp"
|
||||
#include "block_gemm_pipeline_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "tile_gemm_shape.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -65,23 +63,45 @@ struct FlashAttentionFwdImpl
|
||||
{
|
||||
constexpr index_t kNPerBlock = kN1PerBlock;
|
||||
constexpr index_t kKPerBlock = kK1PerBlock;
|
||||
constexpr index_t kPad = 1;
|
||||
// 2% faster than use kK1 = 8
|
||||
constexpr index_t kK1 = 4;
|
||||
constexpr index_t kKPack = 4;
|
||||
|
||||
constexpr auto dataTypeSize = sizeof(VDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / dataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / dataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kK1>{}, number<kNPerBlock>{}, number<kK1>{}),
|
||||
make_tuple(number<(kNPerBlock + kPad) * kK1>{}, number<kK1>{}, number<1>{}),
|
||||
number<kK1>{},
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kK1>{}, number<kK1>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
@@ -132,6 +152,10 @@ struct FlashAttentionFwdImpl
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
// Block GEMM0 pipeline and Block GEMM1
|
||||
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
|
||||
constexpr auto gemm1 = BlockGemm1{};
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetStaticLdsSize()];
|
||||
|
||||
@@ -146,7 +170,10 @@ struct FlashAttentionFwdImpl
|
||||
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{});
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram, make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}), {iM0, 0});
|
||||
q_dram,
|
||||
make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}),
|
||||
{iM0, 0},
|
||||
BlockGemm0Policy::template MakeADramTileDistribution<BlockGemm0Problem>());
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram, make_tuple(number<kN0PerBlock>{}, number<kK0PerBlock>{}), {0, 0});
|
||||
@@ -156,22 +183,32 @@ struct FlashAttentionFwdImpl
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{iN1, 0},
|
||||
MakeVDramTileDistribution());
|
||||
|
||||
// Q in Register
|
||||
auto q_reg_tensor = make_static_distributed_tensor<QDataType>(
|
||||
BlockGemm0Policy::template MakeARegBlockDescriptor<BlockGemm0Problem>());
|
||||
// Q in register
|
||||
auto q_reg_tensor = load_tile(q_dram_window);
|
||||
|
||||
// V LDS and LDS window
|
||||
// V LDS occupies the same LDS allocation Q/K LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr), MakeVLdsBlockDescriptor());
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// V LDS tile window for store
|
||||
auto v_copy_lds_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
v_dram_window.get_tile_distribution());
|
||||
|
||||
// V LDS tile for block GEMM
|
||||
auto v_lds_gemm_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
|
||||
#else
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM0 pipeline and Block GEMM1
|
||||
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
|
||||
constexpr auto gemm1 = BlockGemm1{};
|
||||
#endif
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
@@ -209,22 +246,19 @@ struct FlashAttentionFwdImpl
|
||||
// loop over Column of S (J loop)
|
||||
index_t iN0 = 0;
|
||||
|
||||
// Cold Q_Reg_Cache
|
||||
s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr);
|
||||
do
|
||||
{
|
||||
// Hot Q_Reg_Cache
|
||||
if(iN0 > 0)
|
||||
{
|
||||
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
|
||||
}
|
||||
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
|
||||
|
||||
// S{j}
|
||||
const auto s =
|
||||
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// prefetch load v tile
|
||||
const auto v_prefetch = load_tile(v_dram_window);
|
||||
|
||||
auto v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
#endif
|
||||
// m_local = rowmax(S{j})
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s, sequence<1>{}, f_max, std::numeric_limits<SMPLComputeDataType>::lowest());
|
||||
@@ -274,10 +308,30 @@ struct FlashAttentionFwdImpl
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
store_tile(v_lds_window, v_prefetch);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
|
||||
|
||||
// Oacc{j}
|
||||
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
store_tile(v_lds_window, v);
|
||||
block_sync_lds();
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using VLdsTile = typename decltype(gemm1)::BLdsTile;
|
||||
VLdsTile vWarpTile;
|
||||
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
@@ -288,29 +342,58 @@ struct FlashAttentionFwdImpl
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
}
|
||||
if constexpr(k1_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) {
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 1
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
|
||||
// Global read 2
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
store_tile(v_lds_window, v);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1.template HotLoopScheduler<8, 4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
v_lds_window);
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
#endif
|
||||
// move tile windows
|
||||
move_tile_window(k_dram_window, {kN0PerBlock, 0});
|
||||
iN0 += kN0PerBlock;
|
||||
|
||||
26
example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt
Executable file → Normal file
26
example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/CMakeLists.txt
Executable file → Normal file
@@ -7,8 +7,8 @@ endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FLASH_ATTENTION_FWD_ENABLE_APIS}
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api ${FLASH_ATTENTION_FWD_ENABLE_APIS}
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
@@ -21,21 +21,21 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/flash_attention_fwd_blobs.txt FLASH_ATT
|
||||
add_custom_command(
|
||||
OUTPUT ${FLASH_ATTENTION_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FLASH_ATTENTION_FWD_ENABLE_APIS}
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api ${FLASH_ATTENTION_FWD_ENABLE_APIS}
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--gen_blobs
|
||||
)
|
||||
|
||||
set(EXAMPLE_REDUCE "codegen_basic_flash_attention_fwd")
|
||||
message("adding example ${EXAMPLE_REDUCE}")
|
||||
|
||||
add_executable(${EXAMPLE_REDUCE}
|
||||
EXCLUDE_FROM_ALL
|
||||
add_executable(${EXAMPLE_REDUCE}
|
||||
EXCLUDE_FROM_ALL
|
||||
flash_attention_fwd.cpp
|
||||
)
|
||||
|
||||
target_include_directories(${EXAMPLE_REDUCE}
|
||||
PRIVATE
|
||||
target_include_directories(${EXAMPLE_REDUCE}
|
||||
PRIVATE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
|
||||
@@ -45,14 +45,14 @@ message("FLASH_ATTENTION_FWD_GEN_BLOBS = ${FLASH_ATTENTION_FWD_GEN_BLOBS}")
|
||||
|
||||
|
||||
set(EXAMPLE_REDUCE_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress
|
||||
)
|
||||
|
||||
target_compile_options(${EXAMPLE_REDUCE}
|
||||
PRIVATE
|
||||
target_compile_options(${EXAMPLE_REDUCE}
|
||||
PRIVATE
|
||||
${EXAMPLE_REDUCE_COMPILE_OPTIONS}
|
||||
)
|
||||
|
||||
|
||||
@@ -26,6 +26,251 @@ struct BlockGemmARegBSmemCRegV1
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
// B block tile distribution encoding used when loading B from LDS.
//
// Combines an "outer" (block-level) encoding, built from the warp counts and the
// per-warp iteration counts, with the warp GEMM's own BWarpDstrEncoding, and
// returns the embedded result. Callers in this file turn the result into a tile
// distribution with make_static_tile_distribution().
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
// Policy-selected configuration for this block's M size: element 0 is the warp
// GEMM type, elements 1 and 2 are read below as MWarp and NWarp.
constexpr auto config =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
// How many warp-GEMM tiles each warp covers along N and along K.
constexpr index_t NIterPerWarp = Problem::BlockGemmShape::kN / (NWarp * WG::kN);
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
// Outer encoding over (N: NIterPerWarp x NWarp, K: KIterPerWarp).
// NOTE(review): sequence<MWarp> is presumably the replication dimension (warps that
// differ only in their M position would hold the same B data) - confirm against
// tile_distribution_encoding's parameter order.
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
// Embed the warp-level B encoding inside the block-level one.
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
// Emits the scheduling hints for one hot-loop iteration.
//
// The body consists only of compile-time constants and calls to
// __builtin_amdgcn_sched_group_barrier; it derives how many global-load, LDS-write,
// LDS-read and MFMA instructions one iteration contains and then describes, in two
// stages, the order in which groups of them should be interleaved.
//
// Template parameters:
//   VectorSizeB - divisor used to count B global-load instructions
//                 (presumably elements per buffer load - confirm).
//   SmemPack    - used as B_LDS_RW_Width, the divisor for counting LDS read/write
//                 instructions (presumably elements per LDS access - confirm).
//
// NOTE(review): num_buffer_load_inst_b is used as a divisor and as a loop bound
// below; if NPerBlock * KPerBlock < kBlockSize * VectorSizeB it evaluates to 0 and
// the divisions become a compile-time division by zero.
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
// Warp GEMM type (element 0) and M-direction warp count (element 1) from the policy.
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
constexpr index_t MPerXDL = WG::kM;
|
||||
constexpr index_t NPerXDL = WG::kN;
|
||||
constexpr index_t KPerXDL = WG::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = config.template get<1>();
|
||||
|
||||
constexpr index_t B_LDS_RW_Width = SmemPack;
|
||||
|
||||
// Per-thread instruction counts for moving one B block tile (N x K elements).
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * VectorSizeB);
|
||||
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
// The read count is scaled by WaveNumM relative to the write count.
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (kBlockSize * B_LDS_RW_Width);
|
||||
|
||||
// MFMA instructions per wave: block tile volume divided by the number of waves and
// by the volume of one MFMA (MPerXDL x NPerXDL x KPerXDL).
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(kBlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
// B split schedule
|
||||
// If one LDS access is 16 bytes wide the read count is used as is; otherwise it is
// halved.
constexpr auto num_ds_read_inst_b = B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16
|
||||
? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
|
||||
|
||||
// Cycle estimates used to decide how many LDS reads fit next to one MFMA:
// rate = ceil((mfma_cycle - 4) / (2 * ds_read_b_issue_cycle)).
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_RW_Width * sizeof(BDataType) / kPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
// Number of MFMAs that get LDS reads attached in stage 2:
// ceil(num_ds_read_inst_b / ds_read_b_mfma_rate).
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
// The MFMAs not reserved for stage 2 are spread evenly over the global loads; each
// global load is preceded by its share of LDS writes, each followed by 1 or 2 MFMAs.
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue = num_mfma_stage1 / (num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
constexpr auto num_mfma_per_dswrite_b =
|
||||
(num_mfma_per_issue - num_dswrite_per_issue_b * 2 >= 1) ? 2 : 1;
|
||||
|
||||
// The loop indices are not needed by the bodies; they are assigned to `ignore`.
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, num_mfma_per_dswrite_b, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
// MFMAs of this issue slot that were not already paired with an LDS write above.
__builtin_amdgcn_sched_group_barrier(0x008,
|
||||
num_mfma_per_issue - num_mfma_per_dswrite_b *
|
||||
num_dswrite_per_issue_b,
|
||||
0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
// One group of LDS reads followed by one MFMA, num_dsread_b_mfma times.
// NOTE(review): the condition tests the reads left *after* this iteration, so when
// num_ds_read_inst_b is not a multiple of ds_read_b_mfma_rate the second-to-last
// iteration also takes the else branch and requests only the remainder - confirm
// this is intended.
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
// Remainder: reads left over after (num_dsread_b_mfma - 1) full groups.
__builtin_amdgcn_sched_group_barrier(0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
|
||||
ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B, with B supplied as an already-loaded BLdsTile.
//
// Unlike the overload that takes a B block window, this one slices B directly out
// of b_block_tensor_tmp.
//
//   c_block_tensor     - accumulator; each (mIter, nIter) warp tile is read, updated by
//                        the warp GEMM and written back in place.
//   a_block_tensor_tmp - A tile; its thread buffer is copied into a tensor that uses the
//                        A distribution constructed in this function.
//   b_block_tensor_tmp - B tile of type BLdsTile.
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp>
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BLdsTile& b_block_tensor_tmp) const
|
||||
{
|
||||
// Element types of the three tensors must match the problem's A/B/C data types.
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// Tile extents are taken from the tensors: M and K from A, N from C.
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
|
||||
// ... and must agree with the block GEMM shape this struct was instantiated for.
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
// Policy-selected warp GEMM type (element 0) and warp counts (elements 1 and 2).
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
// Number of warp-GEMM tiles each warp iterates over along M, N and K.
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
// Block-level (outer) encoding for A over (M: MIterPerWarp x MWarp, K: KIterPerWarp).
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
// Block-level (outer) encoding for C over (M: MIterPerWarp x MWarp, N: NIterPerWarp x NWarp).
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
// Embed the warp-level encodings inside the block-level ones.
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
// construct A-block-tensor from A-block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
// Raw copy of the per-thread data; no check is made that the two layouts agree
// (see the FIXME above).
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using BWarpDstr = typename WG::BWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using BWarpTensor = typename WG::BWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
// Per-warp Y-dimension lengths; used as the slice extent when cutting one warp tile
// out of a block tensor.
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
static constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
// All-zero Y origins for the warp-level part of each slice.
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
// Loop order is K (outer), M, N (inner); the A warp tile is sliced once per (k, m)
// and reused for every n.
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from the pre-loaded B tile (b_block_tensor_tmp)
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor_tmp.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
@@ -38,6 +283,8 @@ struct BlockGemmARegBSmemCRegV1
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
@@ -46,7 +293,7 @@ struct BlockGemmARegBSmemCRegV1
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
@@ -180,6 +427,8 @@ struct BlockGemmARegBSmemCRegV1
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
@@ -188,7 +437,7 @@ struct BlockGemmARegBSmemCRegV1
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
|
||||
@@ -10,10 +10,25 @@ namespace ck_tile {
|
||||
|
||||
// Default policy: selects the warp GEMM and the warp layout for
// BlockGemmARegBSmemCRegV1.
//
// NOTE(review): this hunk is a rendered diff without +/- markers. Judging by the hunk
// header (10 old lines, 25 new lines), the `template <typename Problem>` line and the
// unconditional `return make_tuple(...)` are the pre-change lines, and the kM0-based
// `if constexpr` chain is the post-change body - confirm against the repository.
struct BlockGemmARegBSmemCRegV1DefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
// Returns a tuple (warp GEMM object, two integer warp counts) chosen by kM0.
// Callers in this file read element 1 as MWarp and element 2 as NWarp.
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
// kM0 == 64: M16N16K16 warp GEMM with warp counts (4, 1).
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
// kM0 == 32: M16N16K16 warp GEMM with warp counts (2, 1).
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
// kM0 == 128: M32N32K8 warp GEMM with warp counts (4, 1).
else if constexpr(kM0 == 128)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE(review): static_assert(false, ...) in a template is only guaranteed to be
// accepted when this branch is discarded from C++23 (P2593) onwards; older
// compilers may reject the template outright. A condition that depends on kM0
// would be portable - confirm the toolchain in use.
static_assert(false, "Unsupported configuration for warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -10,10 +10,25 @@ namespace ck_tile {
|
||||
|
||||
// K8 policy: selects the warp GEMM and the warp layout for BlockGemmARegBSmemCRegV1.
// The warp GEMM types named here are the K32 / K16 variants.
//
// NOTE(review): this hunk is a rendered diff without +/- markers. Judging by the hunk
// header (10 old lines, 25 new lines), the `template <typename Problem>` line and the
// unconditional `return make_tuple(...)` are the pre-change lines, and the kM0-based
// `if constexpr` chain is the post-change body - confirm against the repository.
struct BlockGemmARegBSmemCRegV1K8Policy
|
||||
{
|
||||
template <typename Problem>
|
||||
// Returns a tuple (warp GEMM object, two integer warp counts) chosen by kM0.
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
// kM0 == 64: M16N16K32 warp GEMM with warp counts (4, 1).
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
// kM0 == 32: M16N16K32 warp GEMM with warp counts (2, 1).
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
// kM0 == 128: M32N32K16 warp GEMM with warp counts (4, 1).
else if constexpr(kM0 == 128)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE(review): static_assert(false, ...) in a template is only guaranteed to be
// accepted when this branch is discarded from C++23 (P2593) onwards; older
// compilers may reject the template outright. A condition that depends on kM0
// would be portable - confirm the toolchain in use.
static_assert(false, "Unsupported configuration for warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -13,16 +13,13 @@ namespace ck_tile {
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, index_t kHeadDim>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
Problem,
|
||||
BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>>
|
||||
template <typename Problem, typename Policy>
|
||||
struct BlockGemmPipelineAGmemBGmemCReg
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
using Policy = BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy<kHeadDim>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
@@ -58,8 +55,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
|
||||
ignore = a_element_func;
|
||||
@@ -135,6 +131,8 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
@@ -159,6 +157,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -218,6 +219,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
ignore = b_element_func;
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
|
||||
// A tile in Reg,blockTensor
|
||||
// This tensor distribution used to construct both distributed tensor for local buffer store
|
||||
// and read. without buffer address info
|
||||
@@ -257,58 +261,90 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto block_gemm = Policy::template GetBlockGemm<Problem>();
|
||||
b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = decltype(block_gemm(
|
||||
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
static_for<0, k_loops, 1>{}([&](auto i_k0) {
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using BLdsTile = typename decltype(block_gemm)::BLdsTile;
|
||||
BLdsTile bWarpTile;
|
||||
|
||||
// Global read 0
|
||||
auto b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
// LDS write 0
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 1
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
block_sync_lds();
|
||||
|
||||
// LDS read 0
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
}
|
||||
|
||||
if constexpr(k_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 1
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
|
||||
// Global read 2
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
// LDS read 1
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
b_block_tile = load_tile(b_copy_dram_window);
|
||||
block_gemm.HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
if constexpr(k_loops > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
@@ -316,13 +352,15 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
bWarpTile = load_tile(b_lds_gemm_window);
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
bWarpTile);
|
||||
}
|
||||
|
||||
#endif
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
@@ -336,9 +374,9 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](const ADataType & a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](const BDataType & b) { return b; },
|
||||
a_reg_block_tensor_tmp,
|
||||
p_smem);
|
||||
}
|
||||
@@ -350,7 +388,7 @@ struct BlockGemmPipelineAGmemBGmemCReg<
|
||||
{
|
||||
return operator()(
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](const BDataType & b) { return b; },
|
||||
a_reg_block_tensor_tmp,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
@@ -3,43 +3,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp"
|
||||
#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// NOTE: Assume A is K-Major
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
template <index_t AKDim_>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
{
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
return policy_impl::make_a_reg_block_descriptor<Problem, BlockGemm>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
return policy_impl::make_b_lds_block_descriptor_3d_pad<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
return policy_impl::make_a_dram_tile_distribution_skip_lds<Problem, BlockGemm>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
return policy_impl::make_b_dram_tile_distribution<Problem>();
|
||||
}
|
||||
static constexpr index_t AKDim = AKDim_;
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto GetBlockGemm()
|
||||
@@ -48,13 +20,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
|
||||
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AKDim_>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
: BlockGemmPipelineAGmemBGmemCRegSkipALdsPolicy
|
||||
{
|
||||
static constexpr index_t AKDim = AKDim_;
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
@@ -62,11 +27,13 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = AKDim;
|
||||
|
||||
constexpr auto config =
|
||||
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem, kMPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
@@ -91,6 +58,87 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
// Tile distribution used for the A operand's DRAM window.
// Forwards to MakeARegBlockDescriptor<Problem>(), so the DRAM-side distribution is
// the same object as the register-side one.
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
return MakeARegBlockDescriptor<Problem>();
|
||||
}
|
||||
|
||||
// Builds the tensor descriptor for the B block tile in LDS.
//
// The result is a 2-D (N, K) view produced by a chain of transforms over a 3-D
// storage layout: naive 3-D descriptor -> XOR transform -> unmerge -> merge.
// Shape inputs are Problem::BlockGemmShape::kN / kK; the K dimension is split into
// packs of kKPack (= 8) elements.
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t kKPack = 8;
|
||||
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
// NLdsLayer = max(1, 128 / (kKPerBlock * sizeof(BDataType))): how many N-rows are
// packed side by side into one storage row.
// NOTE(review): 32 * 4 is presumably 32 LDS banks x 4 bytes - confirm.
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
|
||||
// Storage layout: lengths (K/kKPack * NLdsLayer, N / NLdsLayer, kKPack) with strides
// (kKPack, K * NLdsLayer, 1), i.e. each step along dim 1 skips K * NLdsLayer elements.
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
// XOR transform over dims (1, 0); the kKPack dim passes through unchanged.
// NOTE(review): presumably a swizzle to reduce LDS bank conflicts - confirm.
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
// Split dim 0 into (NLdsLayer, K/kKPack), giving a 4-D view whose upper dims are
// 0: NLdsLayer, 1: N/NLdsLayer, 2: K/kKPack, 3: kKPack.
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
// Merge back to 2-D: N = (N/NLdsLayer, NLdsLayer), K = (K/kKPack, kKPack).
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
// Builds the tile distribution used to load the B block tile (N x K) from DRAM.
//
// N is factored as N0 x N1 x N2 and K as K0 x K1, with every factor derived from the
// block size, the warp size and sizeof(BDataType).
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// K1: number of elements in 16 bytes (presumably one vector load per thread - confirm).
// K0: how many such 16-byte chunks make up one row of K.
// N2: warp size divided by K0.
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
// N1: number of warps in the block. N0: what is left of N after N1 and N2.
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
// NOTE(review): read with the usual encoding convention, the first P dimension maps
// to N1, the second to (N2, K0), and the Y dimensions to (N0, K1) - confirm against
// tile_distribution_encoding's parameter order.
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace policy_impl {
|
||||
|
||||
// 3d + padding
// Builds the LDS descriptor for the A block tile: 3-D storage (K/8, M, 8) with a
// padded outer stride, exposed to callers as a 2-D (M, K) view.
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_a_lds_block_descriptor_3d_pad()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// Lengths (K/8, M, 8). The outer stride is (M + 1) * 8 rather than M * 8, which
// leaves one unused row of 8 elements after each K/8 slab (the "padding";
// presumably there to avoid LDS bank conflicts - confirm).
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
// 2-D view: dim 0 = M (pass-through of storage dim 1),
// dim 1 = K (merge of storage dims 0 and 2).
constexpr auto a_lds_block_desc =
|
||||
transform_tensor_descriptor(a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + padding
// Builds the LDS descriptor for the B block tile: 3-D storage (K/8, N, 8) with a
// padded outer stride, exposed to callers as a 2-D (N, K) view.
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_b_lds_block_descriptor_3d_pad()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
// Lengths (K/8, N, 8). The outer stride is (N + 1) * 8 rather than N * 8, which
// leaves one unused row of 8 elements after each K/8 slab (the "padding";
// presumably there to avoid LDS bank conflicts - confirm).
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
// 2-D view: dim 0 = N (pass-through of storage dim 1),
// dim 1 = K (merge of storage dims 0 and 2).
constexpr auto b_lds_block_desc =
|
||||
transform_tensor_descriptor(b_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
__host__ __device__ static constexpr auto make_a_reg_block_descriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
constexpr index_t NWarp = config.template get<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_a_dram_tile_distribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
__host__ __device__ static constexpr auto make_a_dram_tile_distribution_skip_lds()
|
||||
{
|
||||
constexpr auto config = BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K2 =
|
||||
WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; // WG::WarpGemmAttribute::Impl::kABKPerLane;
|
||||
// // 16 / sizeof(ADataType);
|
||||
constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t K0 = kKPerBlock / (K1 * K2);
|
||||
|
||||
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t M1 = MWarp;
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
sequence<2, 1, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto make_b_dram_tile_distribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(BDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto get_block_gemm()
|
||||
{
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegDefaultPolicy;
|
||||
|
||||
return BlockGemmASmemBSmemCReg<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
} // namespace policy_impl
|
||||
} // namespace ck_tile
|
||||
@@ -29,32 +29,28 @@ int main(int argc, char* argv[])
|
||||
using OaccDataType = float;
|
||||
using ODataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t Batch = 64; // Batch Number * Head Number
|
||||
ck_tile::index_t M0 = 4096; // SequenceLengthQ
|
||||
ck_tile::index_t N0 = 4096; // SequencelengthK
|
||||
ck_tile::index_t K0 = 128; // HeadDim
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
[[maybe_unused]] ck_tile::index_t time_kernel = 0;
|
||||
ck_tile::index_t Batch = 64; // Batch Number * Head Number
|
||||
ck_tile::index_t M0 = 4096; // SequenceLengthQ
|
||||
ck_tile::index_t N0 = 4096; // SequencelengthK
|
||||
ck_tile::index_t K0 = 128; // HeadDim
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 3)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
verification = std::stoi(argv[3]);
|
||||
verification = std::stoi(argv[2]);
|
||||
}
|
||||
|
||||
if(argc == 9)
|
||||
else if(argc == 8)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
verification = std::stoi(argv[3]);
|
||||
Batch = std::stoi(argv[4]);
|
||||
M0 = std::stoi(argv[5]);
|
||||
N0 = std::stoi(argv[6]);
|
||||
K0 = std::stoi(argv[7]);
|
||||
N1 = std::stoi(argv[8]);
|
||||
verification = std::stoi(argv[2]);
|
||||
Batch = std::stoi(argv[3]);
|
||||
M0 = std::stoi(argv[4]);
|
||||
N0 = std::stoi(argv[5]);
|
||||
K0 = std::stoi(argv[6]);
|
||||
N1 = std::stoi(argv[7]);
|
||||
}
|
||||
|
||||
std::array<ck_tile::index_t, 3> q_lengths{Batch, M0, K0};
|
||||
|
||||
@@ -9,13 +9,69 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "block_gemm_pipeline_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "flash_attention_fwd_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Builds a functor mapping a linear block id onto an (m, n) tile coordinate of an M0 x N0 grid.
//
// The grid is split into two parts:
//   * a "grouped" region of update_M0 x update_N0 tiles, where update_N0 is N0 rounded down to a
//     multiple of 8 and update_M0 is M0 rounded down to a multiple of 32 (for non-negative
//     sizes). Ids below update_M0 * update_N0 are remapped using block_1d_id % GroupNum
//     (named xcd_id; presumably the XCD a block is dispatched to -- confirm);
//   * the leftover tiles, which are visited afterwards: first the right-hand strip
//     (columns >= update_N0, all M0 rows), then the bottom strip (rows >= update_M0).
//
// Parameters:
//   M0, N0      - number of tiles along m and n.
//   block_1d_id - linear id passed to the returned functor, expected in [0, M0 * N0).
// Returns (from the functor): make_multi_index(m, n).
//
// The grouped formulas divide by GroupNum * M01 * (update_N0 / 2), which is zero whenever
// N0 < 8. They are therefore evaluated only for ids inside the grouped region; for every other
// id the result comes from the leftover path alone, so skipping them changes no result and
// avoids an integer division by zero.
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
{
    return [=](index_t block_1d_id) {
        constexpr index_t M01      = 4;
        constexpr index_t GroupNum = 8;

        const auto update_N0 = ((((N0 / 2) * 2) / 2) / M01) * M01 * 2;
        const auto update_M0 =
            ((M0 / (GroupNum / 2)) * (GroupNum / 2)) / GroupNum / M01 * M01 * GroupNum;

        const auto total_update_size = update_N0 * update_M0;

        index_t m = 0;
        index_t n = 0;

        if(block_1d_id < total_update_size)
        {
            // Grouped region. total_update_size > 0 here, hence update_N0 >= 8 and ridn > 0.
            const auto xcd_id = block_1d_id % GroupNum;

            // Pair up even/odd ids: both members of a pair share l_block_id and differ only in
            // which half of the n range they are sent to.
            const auto l_block_id = block_1d_id - (xcd_id % 2);

            const auto ridn = GroupNum * M01 * (update_N0 / 2);
            const auto rid  = (l_block_id - (l_block_id % GroupNum)) / ridn;
            const auto lu   = (l_block_id % GroupNum) + rid * ridn;

            const auto sub_N0_id = (l_block_id - lu) / (GroupNum * M01);
            const auto sub_M0_id =
                (l_block_id - (sub_N0_id * (GroupNum * M01) + lu)) / GroupNum;

            n = sub_N0_id + (xcd_id % 2) * (update_N0 / 2);
            m = rid * M01 + sub_M0_id + (update_M0 / (GroupNum / 2)) * (xcd_id / 2);
        }
        else
        {
            // Leftover tiles; x is the 1-based index among them.
            auto x          = (block_1d_id + 1) - total_update_size;
            const auto rlen = N0 - update_N0; // width of the right-hand strip

            index_t rm = 0;
            index_t rn = 0;
            if(rlen > 0)
            {
                rm = (x - 1) / rlen;
                rn = x % rlen;
            }

            if(rlen > 0 && rm < M0)
            {
                // Right-hand strip: columns [update_N0, N0), every row.
                n = rn + update_N0;
                m = rm;
            }
            else
            {
                // Bottom strip: rows [update_M0, M0), columns [0, update_N0).
                // For ids in [0, M0 * N0) this is reached only when update_N0 > 0.
                x  = x - rlen * M0;
                rm = (x - 1) / update_N0;
                rn = x % update_N0;
                n  = rn;
                m  = update_M0 + rm;
            }
        }

        return make_multi_index(m, n);
    };
}
|
||||
|
||||
template <typename QDataType, typename KDataType, typename VDataType, typename ODataType>
|
||||
struct FlashAttnArgs
|
||||
{
|
||||
@@ -83,25 +139,21 @@ struct FlashAttentionFwd
|
||||
const index_t BatchStrideV,
|
||||
const index_t BatchStrideO) const
|
||||
{
|
||||
// divide problem
|
||||
const index_t num_tile_m0 = M0 / kM0PerBlock;
|
||||
const index_t num_tile_n1 = N1 / kN1PerBlock;
|
||||
|
||||
const index_t id_block = get_block_id();
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
const index_t num_tile_m0 = integer_divide_ceil(M0, kM0PerBlock);
|
||||
const index_t num_tile_n1 = integer_divide_ceil(N1, kN1PerBlock);
|
||||
|
||||
return make_tuple(quotient, modulus);
|
||||
};
|
||||
const auto block2tile = MakeBlock2TileMap(num_tile_m0, num_tile_n1);
|
||||
|
||||
const auto [itmp, id_tile_n] = f(id_block, num_tile_n1);
|
||||
const auto [id_tile_batch, id_tile_m] = f(itmp, num_tile_m0);
|
||||
const index_t id_tile_batch = id_block / num_tile_n1 / num_tile_m0;
|
||||
const auto id_tile = block2tile(id_block - id_tile_batch * num_tile_n1 * num_tile_m0);
|
||||
|
||||
const index_t iBatch = __builtin_amdgcn_readfirstlane(id_tile_batch);
|
||||
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile_m * kM0PerBlock);
|
||||
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile_n * kN1PerBlock);
|
||||
const index_t iM0 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<0>{}) %
|
||||
num_tile_m0 * kM0PerBlock);
|
||||
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) %
|
||||
num_tile_n1 * kN1PerBlock);
|
||||
|
||||
const auto kernel_impl = FlashAttentionFwdImpl<QDataType,
|
||||
KDataType,
|
||||
@@ -136,169 +188,6 @@ struct FlashAttentionFwd
|
||||
}
|
||||
};
|
||||
|
||||
// // TODO: fwd_api.cpp
|
||||
// template <typename SaccDataType_,
|
||||
// typename SMPLComputeDataType_,
|
||||
// typename PDataType_,
|
||||
// typename OaccDataType_,
|
||||
// index_t kBlockSize_,
|
||||
// index_t kHeadDim_,
|
||||
// index_t kM0PerBlock_,
|
||||
// index_t kN0PerBlock_,
|
||||
// index_t kK0PerBlock_,
|
||||
// index_t kN1PerBlock_,
|
||||
// index_t kK1PerBlock_>
|
||||
// struct flash_attention_fwd_traits_
|
||||
// {
|
||||
// using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
|
||||
// using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
|
||||
// using PDataType = ck_tile::remove_cvref_t<PDataType_>;
|
||||
// using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
|
||||
|
||||
// static constexpr index_t kBlockSize = kBlockSize_;
|
||||
// static constexpr index_t kHeadDim = kHeadDim_;
|
||||
// static constexpr index_t kM0PerBlock = kM0PerBlock_;
|
||||
// static constexpr index_t kN0PerBlock = kN0PerBlock_;
|
||||
// static constexpr index_t kK0PerBlock = kK0PerBlock_;
|
||||
// static constexpr index_t kN1PerBlock = kN1PerBlock_;
|
||||
// static constexpr index_t kK1PerBlock = kK1PerBlock_;
|
||||
|
||||
// static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
// static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
||||
// static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
// };
|
||||
|
||||
// // TODO: fwd_api.cpp, fwd_common.cpp
|
||||
// template <typename SaccDataType,
|
||||
// typename SMPLComputeDataType,
|
||||
// typename PDataType,
|
||||
// typename OaccDataType,
|
||||
// index_t kBlockSize,
|
||||
// index_t kHeadDim,
|
||||
// index_t kM0PerBlock,
|
||||
// index_t kN0PerBlock,
|
||||
// index_t kK0PerBlock,
|
||||
// index_t kN1PerBlock,
|
||||
// index_t kK1PerBlock>
|
||||
// using traits_ = flash_attention_fwd_traits_<SaccDataType,
|
||||
// SMPLComputeDataType,
|
||||
// PDataType,
|
||||
// OaccDataType,
|
||||
// kBlockSize,
|
||||
// kHeadDim,
|
||||
// kM0PerBlock,
|
||||
// kN0PerBlock,
|
||||
// kK0PerBlock,
|
||||
// kN1PerBlock,
|
||||
// kK1PerBlock>;
|
||||
// // fw_api.cpp
|
||||
// // Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
// template <typename QDataType,
|
||||
// typename KDataType,
|
||||
// typename VDataType,
|
||||
// typename ODataType,
|
||||
// typename Traits_>
|
||||
// float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
// const ck_tile::stream_config& stream_config);
|
||||
|
||||
// // TODO: fwd_common.cpp
|
||||
// template <typename QDataType,
|
||||
// typename KDataType,
|
||||
// typename VDataType,
|
||||
// typename ODataType,
|
||||
// typename Traits_>
|
||||
// float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
// const ck_tile::stream_config& stream_config) {
|
||||
// using SaccDataType = typename Traits_::SaccDataType;
|
||||
// using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
|
||||
// using PDataType = typename Traits_::PDataType;
|
||||
// using OaccDataType = typename Traits_::OaccDataType;
|
||||
|
||||
// index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock);
|
||||
|
||||
// std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
// return ck_tile::launch_kernel(stream_config,
|
||||
// ck_tile::make_kernel<Traits_::kBlockSize, Traits_::kBlockPerCu>(
|
||||
// ck_tile::FlashAttentionFwd<QDataType,
|
||||
// KDataType,
|
||||
// VDataType,
|
||||
// SaccDataType,
|
||||
// SMPLComputeDataType,
|
||||
// PDataType,
|
||||
// OaccDataType,
|
||||
// ODataType,
|
||||
// Traits_::kBlockSize,
|
||||
// Traits_::kHeadDim,
|
||||
// Traits_::kM0PerBlock,
|
||||
// Traits_::kN0PerBlock,
|
||||
// Traits_::kK0PerBlock,
|
||||
// Traits_::kN1PerBlock,
|
||||
// Traits_::kK1PerBlock>{},
|
||||
// kGridSize,
|
||||
// Traits_::kBlockSize,
|
||||
// 0,
|
||||
// a.q_ptr,
|
||||
// a.k_ptr,
|
||||
// a.v_ptr,
|
||||
// a.o_ptr,
|
||||
// a.M0,
|
||||
// a.N0,
|
||||
// a.K0,
|
||||
// a.N1,
|
||||
// a.Batch,
|
||||
// a.strideQ, // StrideQ
|
||||
// a.strideK, // StrideK
|
||||
// a.strideV, // StrideV
|
||||
// a.strideO, // StrideO
|
||||
// a.batchStrideQ, // BatchStrideQ
|
||||
// a.batchStrideK, // BatchStrideK
|
||||
// a.batchStrideV, // BatchStrideV
|
||||
// a.batchStrideO)); // BatchStrideO
|
||||
// }
|
||||
|
||||
// // TODO: change to only declare
|
||||
// // TODO: fwd_api.cpp
|
||||
// template <typename QDataType,
|
||||
// typename KDataType,
|
||||
// typename VDataType,
|
||||
// typename SaccDataType,
|
||||
// typename SMPLComputeDataType,
|
||||
// typename PDataType,
|
||||
// typename OaccDataType,
|
||||
// typename ODataType>
|
||||
// float flash_attention_fwd(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
// const ck_tile::stream_config& stream_config) {
|
||||
// constexpr ck_tile::index_t kM0PerBlock = 128;
|
||||
// constexpr ck_tile::index_t kN0PerBlock = 128;
|
||||
// constexpr ck_tile::index_t kK0PerBlock = 32;
|
||||
// constexpr ck_tile::index_t kN1PerBlock = 128;
|
||||
// constexpr ck_tile::index_t kK1PerBlock = 32;
|
||||
|
||||
// constexpr ck_tile::index_t kBlockSize = 256;
|
||||
// constexpr ck_tile::index_t kHeadDim = 128;
|
||||
|
||||
// return flash_attention_fwd_<QDataType,
|
||||
// KDataType,
|
||||
// VDataType,
|
||||
// ODataType,
|
||||
// traits_<SaccDataType,
|
||||
// SMPLComputeDataType,
|
||||
// PDataType,
|
||||
// OaccDataType,
|
||||
// kBlockSize,
|
||||
// kHeadDim,
|
||||
// kM0PerBlock,
|
||||
// kN0PerBlock,
|
||||
// kK0PerBlock,
|
||||
// kN1PerBlock,
|
||||
// kK1PerBlock>>
|
||||
// (a, stream_config);
|
||||
|
||||
// }
|
||||
|
||||
// TODO: change to only declare
|
||||
// TODO: fwd_api.cpp
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
|
||||
@@ -4,17 +4,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/core/tensor/tile_distribution.hpp"
|
||||
|
||||
#include "tile_gemm_shape.hpp"
|
||||
#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
|
||||
#include "block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp"
|
||||
#include "block_gemm_pipeline_problem.hpp"
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "tile_gemm_shape.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -65,23 +63,45 @@ struct FlashAttentionFwdImpl
|
||||
{
|
||||
constexpr index_t kNPerBlock = kN1PerBlock;
|
||||
constexpr index_t kKPerBlock = kK1PerBlock;
|
||||
constexpr index_t kPad = 1;
|
||||
// 2% faster than use kK1 = 8
|
||||
constexpr index_t kK1 = 4;
|
||||
constexpr index_t kKPack = 4;
|
||||
|
||||
constexpr auto dataTypeSize = sizeof(VDataType);
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / dataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / dataTypeSize);
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / kK1>{}, number<kNPerBlock>{}, number<kK1>{}),
|
||||
make_tuple(number<(kNPerBlock + kPad) * kK1>{}, number<kK1>{}, number<1>{}),
|
||||
number<kK1>{},
|
||||
make_tuple(number<kKPerBlock / kKPack * NLdsLayer>{},
|
||||
number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPack>{}, number<kKPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kK1>{}, number<kK1>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / NLdsLayer>{},
|
||||
number<kKPerBlock / kKPack * NLdsLayer>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
@@ -132,6 +152,10 @@ struct FlashAttentionFwdImpl
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
// Block GEMM0 pipeline and Block GEMM1
|
||||
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
|
||||
constexpr auto gemm1 = BlockGemm1{};
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetStaticLdsSize()];
|
||||
|
||||
@@ -146,7 +170,10 @@ struct FlashAttentionFwdImpl
|
||||
v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{});
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram, make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}), {iM0, 0});
|
||||
q_dram,
|
||||
make_tuple(number<kM0PerBlock>{}, number<kK0PerBlock>{}),
|
||||
{iM0, 0},
|
||||
BlockGemm0Policy::template MakeADramTileDistribution<BlockGemm0Problem>());
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram, make_tuple(number<kN0PerBlock>{}, number<kK0PerBlock>{}), {0, 0});
|
||||
@@ -156,22 +183,32 @@ struct FlashAttentionFwdImpl
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{iN1, 0},
|
||||
MakeVDramTileDistribution());
|
||||
|
||||
// Q in Register
|
||||
auto q_reg_tensor = make_static_distributed_tensor<QDataType>(
|
||||
BlockGemm0Policy::template MakeARegBlockDescriptor<BlockGemm0Problem>());
|
||||
// Q in register
|
||||
auto q_reg_tensor = load_tile(q_dram_window);
|
||||
|
||||
// V LDS and LDS window
|
||||
// V LDS occupies the same LDS allocation as Q/K LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<VDataType*>(smem_ptr), MakeVLdsBlockDescriptor());
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// V LDS tile window for store
|
||||
auto v_copy_lds_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
v_dram_window.get_tile_distribution());
|
||||
|
||||
// V LDS tile for block GEMM
|
||||
auto v_lds_gemm_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
|
||||
#else
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM0 pipeline and Block GEMM1
|
||||
constexpr auto gemm0_pipeline = BlockGemm0Pipeline{};
|
||||
constexpr auto gemm1 = BlockGemm1{};
|
||||
#endif
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
@@ -209,22 +246,19 @@ struct FlashAttentionFwdImpl
|
||||
// loop over Column of S (J loop)
|
||||
index_t iN0 = 0;
|
||||
|
||||
// Cold Q_Reg_Cache
|
||||
s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr);
|
||||
do
|
||||
{
|
||||
// Hot Q_Reg_Cache
|
||||
if(iN0 > 0)
|
||||
{
|
||||
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
|
||||
}
|
||||
s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr);
|
||||
|
||||
// S{j}
|
||||
const auto s =
|
||||
tile_elementwise_in(type_convert<SMPLComputeDataType, SaccDataType>, s_acc);
|
||||
|
||||
#if defined(TOY_FA_FWD_OPT)
|
||||
// prefetch load v tile
|
||||
const auto v_prefetch = load_tile(v_dram_window);
|
||||
|
||||
auto v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
#endif
|
||||
// m_local = rowmax(S{j})
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s, sequence<1>{}, f_max, std::numeric_limits<SMPLComputeDataType>::lowest());
|
||||
@@ -274,10 +308,30 @@ struct FlashAttentionFwdImpl
|
||||
o_acc(i_j_idx) *= tmp;
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
store_tile(v_lds_window, v_prefetch);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
tile_elementwise_in(type_convert<PDataType, SMPLComputeDataType>, p_compute);
|
||||
|
||||
// Oacc{j}
|
||||
constexpr index_t k1_loops = kN0PerBlock / kK1PerBlock;
|
||||
|
||||
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
store_tile(v_lds_window, v);
|
||||
block_sync_lds();
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
using VLdsTile = typename decltype(gemm1)::BLdsTile;
|
||||
VLdsTile vWarpTile;
|
||||
|
||||
// type cast Pcompute{j} into P{j}
|
||||
const auto p =
|
||||
@@ -288,29 +342,58 @@ struct FlashAttentionFwdImpl
|
||||
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
}
|
||||
if constexpr(k1_loops > 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
static_for<0, k1_loops - 2, 1>{}([&](auto i_k1) {
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write 1
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
|
||||
// Global read 2
|
||||
v_prefetch = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
store_tile(v_lds_window, v);
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1.template HotLoopScheduler<8, 4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
});
|
||||
}
|
||||
// tail
|
||||
{
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
v_lds_window);
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
#endif
|
||||
// move tile windows
|
||||
move_tile_window(k_dram_window, {kN0PerBlock, 0});
|
||||
iN0 += kN0PerBlock;
|
||||
|
||||
@@ -11,16 +11,8 @@ import itertools
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
|
||||
# def get_if_str(idx, total, last_else=True):
|
||||
# if idx == 0:
|
||||
# return 'if'
|
||||
# elif idx < total - 1:
|
||||
# return 'else if'
|
||||
# else:
|
||||
# return 'else' if last_else else 'else if'
|
||||
|
||||
def get_if_str(size_, total, last_else=True):
|
||||
if size_ == "small":
|
||||
if size_ == "head_dim_256_seq_4096":
|
||||
return 'if'
|
||||
else:
|
||||
return 'else if'
|
||||
@@ -34,18 +26,18 @@ def BOOL_MAP(b_) -> str:
|
||||
|
||||
class FlashAttentionFwdCodegen:
|
||||
API_TRAITS_DEFINE = """
|
||||
|
||||
|
||||
template <typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kHeadDim_,
|
||||
index_t kM0PerBlock_,
|
||||
index_t kN0PerBlock_,
|
||||
index_t kK0PerBlock_,
|
||||
index_t kN1PerBlock_,
|
||||
index_t kK1PerBlock_>
|
||||
index_t kBlockSize_ = 256,
|
||||
index_t kHeadDim_ = 128,
|
||||
index_t kM0PerBlock_ = 128,
|
||||
index_t kN0PerBlock_ = 128,
|
||||
index_t kK0PerBlock_ = 64,
|
||||
index_t kN1PerBlock_ = 128,
|
||||
index_t kK1PerBlock_ = 64>
|
||||
struct flash_attention_fwd_traits_
|
||||
{
|
||||
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
|
||||
@@ -60,23 +52,23 @@ struct flash_attention_fwd_traits_
|
||||
static constexpr index_t kK0PerBlock = kK0PerBlock_;
|
||||
static constexpr index_t kN1PerBlock = kN1PerBlock_;
|
||||
static constexpr index_t kK1PerBlock = kK1PerBlock_;
|
||||
|
||||
|
||||
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
||||
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size();
|
||||
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
template <typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
ck_tile::index_t kBlockSize,
|
||||
ck_tile::index_t kHeadDim,
|
||||
ck_tile::index_t kM0PerBlock,
|
||||
ck_tile::index_t kN0PerBlock,
|
||||
ck_tile::index_t kK0PerBlock,
|
||||
ck_tile::index_t kN1PerBlock,
|
||||
ck_tile::index_t kK1PerBlock>
|
||||
ck_tile::index_t kBlockSize = 256,
|
||||
ck_tile::index_t kHeadDim = 128,
|
||||
ck_tile::index_t kM0PerBlock = 128,
|
||||
ck_tile::index_t kN0PerBlock = 128,
|
||||
ck_tile::index_t kK0PerBlock = 64,
|
||||
ck_tile::index_t kN1PerBlock = 128,
|
||||
ck_tile::index_t kK1PerBlock = 64>
|
||||
using traits_ = flash_attention_fwd_traits_<SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
@@ -90,78 +82,6 @@ using traits_ = flash_attention_fwd_traits_<SaccDataType,
|
||||
kK1PerBlock>;
|
||||
"""
|
||||
|
||||
# API_COMMON_HEADER = """
|
||||
# // SPDX-License-Identifier: MIT
|
||||
# // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# #include <ck_tile/core.hpp>
|
||||
# #include "flash_attention_fwd.hpp"
|
||||
# #include <iostream>
|
||||
|
||||
# #pragma once
|
||||
|
||||
# using S = ck_tile::stream_config;
|
||||
# using A = FlashAttnArgs;
|
||||
|
||||
# {F_traits_define}
|
||||
|
||||
# template <typename QDataType,
|
||||
# typename KDataType,
|
||||
# typename VDataType,
|
||||
# typename ODataType,
|
||||
# typename Traits_>
|
||||
# float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
# const ck_tile::stream_config& stream_config) {{
|
||||
# using SaccDataType = typename Traits_::SaccDataType;
|
||||
# using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
|
||||
# using PDataType = typename Traits_::PDataType;
|
||||
# using OaccDataType = typename Traits_::OaccDataType;
|
||||
|
||||
# index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock);
|
||||
|
||||
# if(stream_config.log_level_ > 0)
|
||||
# std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush;
|
||||
|
||||
# return ck_tile::launch_kernel(stream_config,
|
||||
# ck_tile::make_kernel<Traits_::kBlockSize, Traits_::kBlockPerCu>(
|
||||
# ck_tile::FlashAttentionFwd<QDataType,
|
||||
# KDataType,
|
||||
# VDataType,
|
||||
# SaccDataType,
|
||||
# SMPLComputeDataType,
|
||||
# PDataType,
|
||||
# OaccDataType,
|
||||
# ODataType,
|
||||
# Traits_::kBlockSize,
|
||||
# Traits_::kHeadDim,
|
||||
# Traits_::kM0PerBlock,
|
||||
# Traits_::kN0PerBlock,
|
||||
# Traits_::kK0PerBlock,
|
||||
# Traits_::kN1PerBlock,
|
||||
# Traits_::kK1PerBlock>{{}},
|
||||
# kGridSize,
|
||||
# Traits_::kBlockSize,
|
||||
# 0,
|
||||
# a.q_ptr,
|
||||
# a.k_ptr,
|
||||
# a.v_ptr,
|
||||
# a.o_ptr,
|
||||
# a.M0,
|
||||
# a.N0,
|
||||
# a.K0,
|
||||
# a.N1,
|
||||
# a.Batch,
|
||||
# a.strideQ, // StrideQ
|
||||
# a.strideK, // StrideK
|
||||
# a.strideV, // StrideV
|
||||
# a.strideO, // StrideO
|
||||
# a.batchStrideQ, // BatchStrideQ
|
||||
# a.batchStrideK, // BatchStrideK
|
||||
# a.batchStrideV, // BatchStrideV
|
||||
# a.batchStrideO)); // BatchStrideO
|
||||
# }}
|
||||
# """
|
||||
|
||||
API_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
@@ -204,14 +124,6 @@ template float flash_attention_fwd<ck_tile::half_t, ck_tile::half_t, ck_tile::ha
|
||||
}}
|
||||
"""
|
||||
|
||||
# API_PER_DTYPE = """ {F_if}(std::is_same_v<QDataType, {F_q_type}> && std::is_same_v<KDataType, {F_k_type}> && std::is_same_v<VDataType, {F_v_type}> && std::is_same_v<ODataType, {F_o_type}>) {{
|
||||
# {F_per_size_case}
|
||||
# }}
|
||||
# """
|
||||
# API_PER_SIZE_CASE = """ {F_if} {F_SIZE_COND} {{
|
||||
# {F_inner_dispatch}
|
||||
# }}
|
||||
# """
|
||||
API_INNER_CASE = """ {F_if} {F_VEC_COND}
|
||||
r = flash_attention_fwd_<QDataType, KDataType, VDataType, ODataType, traits_<{F_trait_name}>>(a, stream_config);
|
||||
"""
|
||||
@@ -224,7 +136,7 @@ template float flash_attention_fwd<ck_tile::half_t, ck_tile::half_t, ck_tile::ha
|
||||
|
||||
namespace ck_tile {
|
||||
// clang-format off
|
||||
//
|
||||
//
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
|
||||
@@ -315,18 +227,18 @@ namespace ck_tile {{
|
||||
#include "flash_attention_fwd.hpp"
|
||||
|
||||
namespace ck_tile {{
|
||||
|
||||
|
||||
template <typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kHeadDim_,
|
||||
index_t kM0PerBlock_,
|
||||
index_t kN0PerBlock_,
|
||||
index_t kK0PerBlock_,
|
||||
index_t kN1PerBlock_,
|
||||
index_t kK1PerBlock_>
|
||||
index_t kBlockSize_ = 256,
|
||||
index_t kHeadDim_ = 128,
|
||||
index_t kM0PerBlock_ = 128,
|
||||
index_t kN0PerBlock_ = 128,
|
||||
index_t kK0PerBlock_ = 64,
|
||||
index_t kN1PerBlock_ = 128,
|
||||
index_t kK1PerBlock_ = 64>
|
||||
struct flash_attention_fwd_traits_
|
||||
{{
|
||||
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
|
||||
@@ -341,13 +253,13 @@ struct flash_attention_fwd_traits_
|
||||
static constexpr index_t kK0PerBlock = kK0PerBlock_;
|
||||
static constexpr index_t kN1PerBlock = kN1PerBlock_;
|
||||
static constexpr index_t kK1PerBlock = kK1PerBlock_;
|
||||
|
||||
|
||||
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
}};
|
||||
|
||||
|
||||
}};
|
||||
|
||||
|
||||
template <typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
@@ -379,11 +291,11 @@ template <typename QDataType,
|
||||
typename Traits_>
|
||||
float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const ck_tile::stream_config& stream_config) {{
|
||||
using SaccDataType = typename Traits_::SaccDataType;
|
||||
using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
|
||||
using PDataType = typename Traits_::PDataType;
|
||||
using OaccDataType = typename Traits_::OaccDataType;
|
||||
|
||||
using SaccDataType = typename Traits_::SaccDataType;
|
||||
using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
|
||||
using PDataType = typename Traits_::PDataType;
|
||||
using OaccDataType = typename Traits_::OaccDataType;
|
||||
|
||||
index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
@@ -433,7 +345,7 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
# Sort based on dtype
|
||||
t_dtype_dict = {}
|
||||
blobs = self.get_blobs(args)
|
||||
|
||||
|
||||
for blob in blobs:
|
||||
if blob.F_DataTypePair not in t_dtype_dict:
|
||||
t_dtype_dict[blob.F_DataTypePair] = {}
|
||||
@@ -445,47 +357,39 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
for i_d, dtype_ in enumerate(t_dtype_dict):
|
||||
blob_per_t = t_dtype_dict[dtype_]
|
||||
size_str = ''
|
||||
|
||||
|
||||
for i_size, size_ in enumerate(blob_per_t):
|
||||
blob_per_size = blob_per_t[size_]
|
||||
inner_str = ""
|
||||
|
||||
|
||||
for i_b, b_ in enumerate(blob_per_size):
|
||||
for i_ins, ins in enumerate(b_.instance_list):
|
||||
idx_in_size = i_b * len(b_.instance_list) + i_ins
|
||||
len_in_size = sum(len(b.instance_list) for b in blob_per_size)
|
||||
|
||||
|
||||
size_cond = ""
|
||||
if size_ == "small":
|
||||
size_cond = "(a.M0 < 2048 && a.N0 < 2048)"
|
||||
elif size_ == "medium":
|
||||
size_cond = "(a.M0 >= 2048 && a.N0 >= 2048 && a.M0 < 4096 && a.N0 < 4096)"
|
||||
else: # large
|
||||
size_cond = "(a.M0 >= 4096 || a.N0 >= 4096)"
|
||||
|
||||
if size_ == "head_dim_256_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 256 && a.N1 == 256)"
|
||||
elif size_ == "head_dim_128_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 64 && a.N1 == 64)"
|
||||
elif size_ == "head_dim_32_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 32 && a.N1 == 32)"
|
||||
elif size_ == "head_dim_128_seq_2048":
|
||||
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 512 && a.N0 > 512 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_128_seq_512":
|
||||
size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 128 && a.N1 == 128)"
|
||||
else:
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
|
||||
|
||||
inner_str += self.API_INNER_CASE.format(
|
||||
# F_if=get_if_str(idx_in_size, len_in_size, False),
|
||||
F_if=get_if_str(size_, len_in_size, False),
|
||||
F_VEC_COND=size_cond,
|
||||
F_trait_name=ins.trait_name
|
||||
)
|
||||
|
||||
# size_str += self.API_PER_SIZE_CASE.format(
|
||||
# F_if=get_if_str(i_size, len(blob_per_t)),
|
||||
# F_SIZE_COND=size_cond,
|
||||
# F_inner_dispatch=inner_str
|
||||
# )
|
||||
size_str += inner_str
|
||||
|
||||
# q_type, k_type, v_type, o_type = dtype_.split(',')
|
||||
# d_str += self.API_PER_DTYPE.format(
|
||||
# F_if=get_if_str(i_d, len(t_dtype_dict)),
|
||||
# F_q_type=DATA_TYPE_MAP[q_type],
|
||||
# F_k_type=DATA_TYPE_MAP[k_type],
|
||||
# F_v_type=DATA_TYPE_MAP[v_type],
|
||||
# F_o_type=DATA_TYPE_MAP[o_type],
|
||||
# F_per_size_case=size_str
|
||||
# )
|
||||
d_str += size_str
|
||||
|
||||
api_base = self.API_BASE.format(
|
||||
@@ -500,18 +404,24 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
|
||||
# Define kernel configurations for different size categories
|
||||
trait_dict = {
|
||||
"small": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 128, 128, 128, 32, 128, 32),
|
||||
# h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 64, 32, 64, 32),
|
||||
"head_dim_256_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 256, 128, 128, 64, 128, 64),
|
||||
],
|
||||
"medium": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 128, 128, 128, 32, 128, 32),
|
||||
# h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 256, 128, 32, 128, 32),
|
||||
"head_dim_128_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 64, 128, 64),
|
||||
],
|
||||
"head_dim_64_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64),
|
||||
],
|
||||
"head_dim_32_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 32, 32, 32, 32, 32, 32),
|
||||
],
|
||||
"head_dim_128_seq_2048": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 64, 128, 64),
|
||||
],
|
||||
"head_dim_128_seq_512": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 128, 128, 128),
|
||||
],
|
||||
"large": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
# h_traits('fp32', 'fp32', 'fp32', 'fp32', 512, 128, 256, 256, 32, 256, 32),
|
||||
]
|
||||
}
|
||||
|
||||
# Toy example only support fp16
|
||||
@@ -534,16 +444,16 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
new_t.F_PDataType = q_type
|
||||
new_t.F_OaccDataType = 'fp32' # output accumulation in fp32
|
||||
current_traits.append(new_t)
|
||||
|
||||
|
||||
total_blob.append(h_instance(dtype_pair, size_category, current_traits))
|
||||
|
||||
|
||||
return total_blob
|
||||
|
||||
def list_blobs(self, args) -> None:
|
||||
w_p = Path(self.working_path)
|
||||
list_p = w_p / 'flash_attention_fwd_blobs.txt'
|
||||
blobs = self.get_blobs(args)
|
||||
|
||||
|
||||
with list_p.open('w') as list_f:
|
||||
# API related files
|
||||
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
|
||||
@@ -557,11 +467,11 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
w_str = self.content_api(args)
|
||||
(w_p / (self.name_api + ".cpp")).write_text(w_str)
|
||||
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
|
||||
|
||||
|
||||
blobs = self.get_blobs(args)
|
||||
for b in blobs:
|
||||
(w_p / (b.name + ".cpp")).write_text(b.content)
|
||||
|
||||
|
||||
def list_blobs(args):
|
||||
api_list = args.api.split(',')
|
||||
for api in api_list:
|
||||
|
||||
@@ -51,5 +51,5 @@ make -j basic_flash_attention_fwd
|
||||
|
||||
### **Flash Attention Forward Example**
|
||||
```sh
|
||||
./bin/basic_flash_attention_fwd 1 0 1
|
||||
./bin/basic_flash_attention_fwd 1 1
|
||||
```
|
||||
|
||||
Reference in New Issue
Block a user