[CK_TILE] More fmha splitkv optimizations (#1588)

* Use pre-defined constants for readability

* Use vector write for o_acc tensor

* Remove no-longer used policy method

* Deprecate no-longer used policy/pipeline

* Specify gemm0/gemm1 block warps separately in codegen

* Fix wrong ps_idx creation logic

* Add single-warp block gemm

* Supoprt single-warp gemm0

* Make MakeCBlockTile() as static method

* Use MakeCBlockTile() to get underlying tile distribution

* Use kNumGemm1Warps to compute # threads for gemm1

* Put normal case in the if clause

* Refine fmha splitkv block mapping

* Refine & fix the lse_acc/o_acc layout

* Fix wrong LDS size for K tile

* Use kK0=64 for hdim=128,256 fmha splitkv kernels

* Use kK1=64 for hdim=32,64,128 fmha splitkv kernels

* Undo kK0/kK1 changes

* Use more reasonable GetAlignmentV() computation

* Using store_tile() in fmha splitkv kernel epilogue
This commit is contained in:
Po Yen Chen
2024-10-26 18:35:45 +08:00
committed by GitHub
parent 37f7afed1e
commit 54f0e6f4bb
22 changed files with 422 additions and 199 deletions

View File

@@ -157,7 +157,7 @@ struct BlockGemmARegBRegCRegV1
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;

View File

@@ -0,0 +1,237 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV1DefaultPolicy>
struct BlockGemmARegBSmemCRegOneWarpV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static_assert(kBlockSize == get_warp_size(), "Check failed!");
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK,
// "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == 1 && NWarp == 1, "Check failed!");
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
const index_t iNWarp = 0;
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
// constrcut from A-block-tensor from A-Block-tensor-tmp
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
// distribution
auto a_block_tensor =
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
{b_warp_window_tmp}};
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
{
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
{
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
}
}
#else
statically_indexed_array<
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
NIterPerWarp>
b_warp_windows;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
#endif
// check C-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block tensor
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == 1 && NWarp == 1, "Check failed!");
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
tuple<>,
tuple<>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
static_assert(decltype(c_block_dstr_encode)::NDimP == 1, "Check failed!");
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
return c_block_tensor;
}
};
} // namespace ck_tile

View File

@@ -181,7 +181,7 @@ struct BlockGemmARegBSmemCRegV1
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;

View File

@@ -182,7 +182,7 @@ struct BlockGemmARegBSmemCRegV2
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;

View File

@@ -180,7 +180,7 @@ struct BlockGemmASmemBRegCRegV1
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;

View File

@@ -167,7 +167,7 @@ struct BlockGemmASmemBSmemCRegV1
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;