mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
[GEMM] Add define macro for unused a/b blk window
This commit is contained in:
@@ -111,8 +111,8 @@ struct BlockGemmASmemBSmemCReg
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockWindowTmp& a_block_window_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
[[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> &&
|
||||
@@ -127,14 +127,11 @@ struct BlockGemmASmemBSmemCReg
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
// constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
// using WarpGemm = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
#if !defined(ENABLE_PREFETCH)
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
@@ -142,7 +139,7 @@ struct BlockGemmASmemBSmemCReg
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// construct A-warp-window
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
@@ -158,13 +155,12 @@ struct BlockGemmASmemBSmemCReg
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// construct B-warp-window
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
@@ -180,16 +176,16 @@ struct BlockGemmASmemBSmemCReg
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
#pragma message("local data share prefetch")
|
||||
@@ -200,7 +196,7 @@ struct BlockGemmASmemBSmemCReg
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
#endif
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
|
||||
@@ -209,17 +205,17 @@ struct BlockGemmASmemBSmemCReg
|
||||
#else
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
#endif
|
||||
// read C warp tensor from C block tensor
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
// Warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
@@ -231,8 +227,8 @@ struct BlockGemmASmemBSmemCReg
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockWindowTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockWindowTmp& a_block_window_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp,
|
||||
[[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> &&
|
||||
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType>,
|
||||
@@ -246,14 +242,11 @@ struct BlockGemmASmemBSmemCReg
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
// constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
// using WarpGemm = remove_cvref_t<decltype(config.template get(number<0>{}))>;
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
#if !defined(ENABLE_PREFETCH)
|
||||
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
@@ -261,7 +254,7 @@ struct BlockGemmASmemBSmemCReg
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
// construct A-warp-window
|
||||
// Construct A-warp-window
|
||||
auto a_warp_window_tmp = make_tile_window(
|
||||
a_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kM>{}, number<WarpGemm::kK>{}),
|
||||
@@ -277,13 +270,12 @@ struct BlockGemmASmemBSmemCReg
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// construct B-warp-window
|
||||
// Construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WarpGemm::kN>{}, number<WarpGemm::kK>{}),
|
||||
@@ -299,11 +291,11 @@ struct BlockGemmASmemBSmemCReg
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
static_assert(std::is_same_v<CDataType, typename WarpGemm::CDataType>, "wrong!");
|
||||
|
||||
@@ -323,10 +315,10 @@ struct BlockGemmASmemBSmemCReg
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
// hot loop:
|
||||
// Hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
// Read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
a_warp_tensor.get_thread_buffer() = aWarpTile.get_y_sliced_thread_data(
|
||||
@@ -336,7 +328,7 @@ struct BlockGemmASmemBSmemCReg
|
||||
a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
#endif
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
// Read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
b_warp_tensor.get_thread_buffer() = bWarpTile.get_y_sliced_thread_data(
|
||||
@@ -345,10 +337,10 @@ struct BlockGemmASmemBSmemCReg
|
||||
#else
|
||||
b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
#endif
|
||||
// read C warp tensor from C block tensor
|
||||
// Read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
// warp GEMM
|
||||
// Warp GEMM
|
||||
if constexpr(KIterPerWarp == 0)
|
||||
{
|
||||
// c = a * b
|
||||
@@ -364,7 +356,7 @@ struct BlockGemmASmemBSmemCReg
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
}
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
// Write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
|
||||
@@ -307,7 +307,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
// Gemm pipeline start
|
||||
|
||||
#if defined(ENABLE_PREFETCH)
|
||||
|
||||
#pragma message("global prefetch")
|
||||
// Initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
@@ -342,7 +342,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
// Main body
|
||||
if(num_loop > 2)
|
||||
{
|
||||
index_t i = 0;
|
||||
index_t iCounter = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
Reference in New Issue
Block a user