revert mostly back to original comp_async

This commit is contained in:
Sami Remes
2026-01-30 12:40:48 -05:00
parent 2cc0e3d019
commit b124a72ff5
2 changed files with 475 additions and 196 deletions

View File

@@ -298,38 +298,96 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
auto a_tile_windows = generate_tuple(
[&](auto idx) {
/// NOTE: flatmm style byte tensor approach:
// Create tile window with STORAGE dimensions to match LDS
// auto&& tensor_view_tmp = a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view();
// auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
// const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
// auto&& a_tensor_view = make_naive_tensor_view<address_space_enum::global>(
// static_cast<const uint8_t*>(byte_ptr),
// make_tuple(rows, cols / APackedSize),
// make_tuple(cols / APackedSize, 1),
// number<16>{},
// number<1>{});
// return make_tile_window(a_tensor_view,
// make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
// [&]() {
// auto origin = a_dram_block_window_tmp[number<idx>{}].get_window_origin();
// if constexpr(is_a_col_major) {
// origin[0] = origin[0] / APackedSize; // Adjust K origin
// } else {
// origin[1] = origin[1] / APackedSize; // Adjust K origin
// }
// return origin;
// }(),
// Policy::template MakeADramTileDistribution<Problem>());
/// NOTE: re-use original tensor view but with adjusted origin and K/PackedSize
// return make_tile_window(
// a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
// make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
// [&]() {
// auto origin = a_dram_block_window_tmp[number<idx>{}].get_window_origin();
// if constexpr(is_a_col_major) {
// origin[0] = origin[0] / APackedSize; // Adjust K origin
// } else {
// origin[1] = origin[1] / APackedSize; // Adjust K origin
// }
// return origin;
// }(),
// Policy::template MakeADramTileDistribution<Problem>());
/// NOTE: use original shapes
return make_tile_window(
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
[&]() {
auto origin = a_dram_block_window_tmp[number<idx>{}].get_window_origin();
if constexpr(is_a_col_major) {
origin[0] = origin[0] / APackedSize; // Adjust K origin
} else {
origin[1] = origin[1] / APackedSize; // Adjust K origin
}
return origin;
}(),
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeADramTileDistribution<Problem>());
},
number<AsLayout::size()>{});
// B DRAM window(s) for load
auto b_tile_windows = generate_tuple(
[&](auto idx) {
/// NOTE: flatmm style byte tensor approach:
// Create tile window with STORAGE dimensions to match LDS
// auto&& tensor_view_tmp = b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view();
// auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
// const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
// auto&& b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
// static_cast<const uint8_t*>(byte_ptr),
// make_tuple(rows, cols / BPackedSize),
// make_tuple(cols / BPackedSize, 1),
// number<16>{},
// number<1>{});
// return make_tile_window(b_tensor_view,
// make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{}),
// [&]() {
// auto origin = b_dram_block_window_tmp[number<idx>{}].get_window_origin();
// if constexpr(is_b_row_major) {
// origin[0] = origin[0] / BPackedSize; // Adjust K origin
// } else {
// origin[1] = origin[1] / BPackedSize; // Adjust K origin
// }
// return origin;
// }(),
// Policy::template MakeBDramTileDistribution<Problem>());
/// NOTE: re-use original tensor view but with adjusted origin and K/PackedSize
// return make_tile_window(
// b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
// make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{}),
// [&]() {
// auto origin = b_dram_block_window_tmp[number<idx>{}].get_window_origin();
// if constexpr(is_b_row_major) {
// origin[0] = origin[0] / BPackedSize; // Adjust K origin
// } else {
// origin[1] = origin[1] / BPackedSize; // Adjust K origin
// }
// return origin;
// }(),
// Policy::template MakeBDramTileDistribution<Problem>());
/// NOTE: use original shapes
return make_tile_window(
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{}),
[&]() {
auto origin = b_dram_block_window_tmp[number<idx>{}].get_window_origin();
if constexpr(is_b_row_major) {
origin[0] = origin[0] / BPackedSize; // Adjust K origin
} else {
origin[1] = origin[1] / BPackedSize; // Adjust K origin
}
return origin;
}(),
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
Policy::template MakeBDramTileDistribution<Problem>());
},
number<BsLayout::size()>{});
@@ -382,22 +440,41 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
// this pipeline has a pair of LDS buffers per logical tile
// TODO: check for packed size - are these blocks too big?
/// NOTE: flatmm style byte tensor approach:
// auto&& [a_lds_block0, b_lds_block0] = Base::template GetABLdsTensorViews<uint8_t, uint8_t>(p_smem_0);
// auto&& [a_lds_block1, b_lds_block1] = Base::template GetABLdsTensorViews<uint8_t, uint8_t>(p_smem_1);
/// NOTE: with original fp4 types:
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
// set up LDS tile shapes - always use STORAGE dimensions for K
/// NOTE: flatmm style byte tensor approach:
// constexpr auto a_lds_shape = []() {
// if constexpr(is_a_load_tr_v)
// return make_tuple(number<KPerBlock / APackedSize>{}, number<MPerBlock>{});
// else
// return make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{});
// }();
// constexpr auto b_lds_shape = []() {
// if constexpr(is_b_load_tr_v)
// return make_tuple(number<KPerBlock / BPackedSize>{}, number<NPerBlock>{});
// else
// return make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{});
// }();
/// NOTE: use original shapes
constexpr auto a_lds_shape = []() {
if constexpr(is_a_load_tr_v)
return make_tuple(number<KPerBlock / APackedSize>{}, number<MPerBlock>{});
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
else
return make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{});
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
}();
constexpr auto b_lds_shape = []() {
if constexpr(is_b_load_tr_v)
return make_tuple(number<KPerBlock / BPackedSize>{}, number<NPerBlock>{});
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
else
return make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{});
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
}();
// LDS tile windows for storing, one per LDS buffer
@@ -413,10 +490,16 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
/// NOTE: flatmm style way to calculate steps with packed size
// constexpr ADramTileWindowStep a_dram_tile_window_step =
// is_a_col_major ? make_array(KPerBlock / APackedSize, 0) : make_array(0, KPerBlock / APackedSize);
// constexpr BDramTileWindowStep b_dram_tile_window_step =
// is_b_row_major ? make_array(KPerBlock / BPackedSize, 0) : make_array(0, KPerBlock / BPackedSize);
/// NOTE: use original steps and assume that PackedSize is correctly applied elsewhere
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock / APackedSize, 0) : make_array(0, KPerBlock / APackedSize);
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock / BPackedSize, 0) : make_array(0, KPerBlock / BPackedSize);
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// read A(0), B(0) from DRAM to LDS window(0)
// and advance the DRAM windows
@@ -426,8 +509,13 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
// Initialize WarpGemm for MX scaling
using WarpGemm = typename remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>::WarpGemm;
using CWarpTensor = typename WarpGemm::CWarpTensor;
// using WarpGemm = typename remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>::WarpGemm;
// using CWarpTensor = typename WarpGemm::CWarpTensor;
// Initialize block gemm and C block tile
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
clear_tile(c_block_tile);
// read A(1), B(1) from DRAM to LDS window(1)
// and advance the DRAM windows
@@ -449,6 +537,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
ALdsTile a_block_tile0, a_block_tile1;
BLdsTile b_block_tile0, b_block_tile1;
// Some sanity checks on the LDS tile sizes
static_assert(sizeof(ALdsTile) == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize) * NWarp / BlockSize, "ALdsTile size is wrong!");
static_assert(sizeof(BLdsTile) == NPerBlock * (KPerBlock * sizeof(BDataType) / BPackedSize) * MWarp / BlockSize, "BLdsTile size is wrong!");
static_assert(Policy::template GetSmemSizeA<Problem>() == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize), "SmemSizeA size is wrong!");
@@ -496,36 +585,44 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
move_tile_window(scale_b_dram_window, {KPerBlock / ScaleBlockSize / KXdlPack, 0});
};
constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
if constexpr(is_a_load_tr_v)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename decltype(ALdsTileDistr)::DstrEncode,
typename Problem::ADataType>::TransposedDstrEncode{});
else
return ALdsTileDistr;
}();
constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
if constexpr(is_b_load_tr_v)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename decltype(BLdsTileDistr)::DstrEncode,
typename Problem::BDataType>::TransposedDstrEncode{});
else
return BLdsTileDistr;
}();
// constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
// if constexpr(is_a_load_tr_v)
// return make_static_tile_distribution(
// typename InputTileDistributionTraits<
// typename decltype(ALdsTileDistr)::DstrEncode,
// typename Problem::ADataType>::TransposedDstrEncode{});
// else
// return ALdsTileDistr;
// }();
// constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
// if constexpr(is_b_load_tr_v)
// return make_static_tile_distribution(
// typename InputTileDistributionTraits<
// typename decltype(BLdsTileDistr)::DstrEncode,
// typename Problem::BDataType>::TransposedDstrEncode{});
// else
// return BLdsTileDistr;
// }();
// LDS tile windows for reading;
// they share the data pointer with the LDS windows for storing
// but also associate with a distribution to produce a register tile when reading
auto a_lds_ld_window0 =
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, ALdsTileDistr);
auto a_lds_ld_window1 =
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, ALdsTileDistr);
auto b_lds_ld_window0 =
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, BLdsTileDistr);
auto b_lds_ld_window1 =
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, BLdsTileDistr);
// auto a_lds_ld_window0 =
// make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, Policy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
// auto a_lds_ld_window1 =
// make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, Policy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
// auto b_lds_ld_window0 =
// make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, Policy::template MakeMX_BLDSBytes_TileDistribution<Problem>());
// auto b_lds_ld_window1 =
// make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, Policy::template MakeMX_BLDSBytes_TileDistribution<Problem>());
static_assert(!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>) &&
!(is_tile_window_linear_v<decltype(a_lds_ld_window1)>) &&
@@ -534,61 +631,62 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
"LDS windows must not be linear");
// Create warp-level C tensors (one per M/N iteration)
statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp> c_warp_tensors;
// statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp> c_warp_tensors;
// Initialize C tensors
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
clear_tile(c_warp_tensors(mIter)(nIter));
});
});
/// TODO: create CBlockTile with block_gemm.MakeCBlockTile()
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// clear_tile(c_warp_tensors(mIter)(nIter));
// });
// });
// Warp GEMM loop with MX scaling
auto warp_gemm_loop = [&](const auto& a_block_tile, const auto& b_block_tile, const auto& scale_a, const auto& scale_b) {
// Extract A/B values from block tiles to warp iteration structure
constexpr auto a_warp_y_lengths =
to_sequence(typename WarpGemm::AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<WarpGemm::AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_lengths =
to_sequence(typename WarpGemm::BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<WarpGemm::BWarpDstr::NDimY, 0>{};
// auto warp_gemm_loop = [&](const auto& a_block_tile, const auto& b_block_tile, const auto& scale_a, const auto& scale_b) {
// // Extract A/B values from block tiles to warp iteration structure
// constexpr auto a_warp_y_lengths =
// to_sequence(typename WarpGemm::AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<WarpGemm::AWarpDstr::NDimY, 0>{};
// constexpr auto b_warp_y_lengths =
// to_sequence(typename WarpGemm::BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<WarpGemm::BWarpDstr::NDimY, 0>{};
static_for<0, KIterPerWarp, 1>{}([&](auto k_iter) {
// Map k_iter to packed scale index and OpSel
constexpr index_t kScalePacked = (k_iter * KPerXdl) / (ScaleBlockSize * KXdlPack);
// constexpr index_t kScaleInPack = ((k_iter * KPerXdl) / ScaleBlockSize) % KXdlPack;
constexpr index_t kScaleInPack = k_iter;
// static_for<0, KIterPerWarp, 1>{}([&](auto k_iter) {
// // Map k_iter to packed scale index and OpSel
// constexpr index_t kScalePacked = (k_iter * KPerXdl) / (ScaleBlockSize * KXdlPack);
// // constexpr index_t kScaleInPack = ((k_iter * KPerXdl) / ScaleBlockSize) % KXdlPack;
// constexpr index_t kScaleInPack = k_iter;
static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) {
constexpr auto OpSelA = kScaleInPack;
// static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) {
// constexpr auto OpSelA = kScaleInPack;
// read A warp tensor from A block tensor
typename WarpGemm::AWarpTensor a_warp_tensor;
// // read A warp tensor from A block tensor
// typename WarpGemm::AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<m_iter, k_iter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data(
// merge_sequences(sequence<m_iter, k_iter>{}, a_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
constexpr auto OpSelB = kScaleInPack;
// static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
// constexpr auto OpSelB = kScaleInPack;
// read B warp tensor from B block tensor
typename WarpGemm::BWarpTensor b_warp_tensor;
// // read B warp tensor from B block tensor
// typename WarpGemm::BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<n_iter, k_iter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data(
// merge_sequences(sequence<n_iter, k_iter>{}, b_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
WarpGemm{}.template operator()<OpSelA, OpSelB>(
c_warp_tensors(m_iter)(n_iter),
a_warp_tensor,
b_warp_tensor,
scale_a(m_iter)(number<kScalePacked>{}).get_thread_buffer()[0],
scale_b(n_iter)(number<kScalePacked>{}).get_thread_buffer()[0]);
});
});
});
};
// WarpGemm{}.template operator()<OpSelA, OpSelB>(
// c_warp_tensors(m_iter)(n_iter),
// a_warp_tensor,
// b_warp_tensor,
// scale_a(m_iter)(number<kScalePacked>{}).get_thread_buffer()[0],
// scale_b(n_iter)(number<kScalePacked>{}).get_thread_buffer()[0]);
// });
// });
// });
// };
// write to LDS window(0) must complete before the local prefetch
block_sync_lds_direct_load();
@@ -636,12 +734,16 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
b_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-3) = A(i-3) @ B(i-3) with MX scaling
warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_ping;
ignore = scale_b_tile_ping;
HotLoopScheduler();
// Load scales for iteration i+2 (ping)
if (i_global_read + 2 < num_loop) {
load_scales_(scale_a_tile_ping, scale_b_tile_ping);
}
HotLoopScheduler();
}
// pong
{
@@ -661,13 +763,17 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
b_tile_windows[number<0>{}],
b_dram_tile_window_step);
// C(i-2) = A(i-2) @ B(i-2) with MX scaling
warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_pong;
ignore = scale_b_tile_pong;
HotLoopScheduler();
// Load scales for iteration i+2 (pong)
/// TODO: check condition
if (i_global_read + 2 < num_loop) {
load_scales_(scale_a_tile_pong, scale_b_tile_pong);
}
HotLoopScheduler();
}
i_global_read += 2;
} while(i_global_read < num_loop);
@@ -681,7 +787,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
// C(num_loop-2) = A(num_loop-2) @ B(num_loop-2) with MX scaling
warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_ping;
ignore = scale_b_tile_ping;
/// TODO: load next scales to ping for the last iteration
}
{
@@ -691,11 +801,19 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling
warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_pong;
ignore = scale_b_tile_pong;
}
{
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_ping;
ignore = scale_b_tile_ping;
}
}
else if(TailNum == TailNumber::Two)
@@ -706,36 +824,48 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling
warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_ping;
ignore = scale_b_tile_ping;
}
{
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_pong;
ignore = scale_b_tile_pong;
}
}
else if(TailNum == TailNumber::One)
{
block_sync_lds();
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
/// TODO: remove these after creating a block gemm with scales
ignore = scale_a_tile_ping;
ignore = scale_b_tile_ping;
__builtin_amdgcn_sched_barrier(0);
}
// Convert warp-level C tensors to block tile format
auto c_block_tile = BlockGemm{}.MakeCBlockTile();
using CWarpDstr = typename WarpGemm::CWarpDstr;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// auto c_block_tile = BlockGemm{}.MakeCBlockTile();
// using CWarpDstr = typename WarpGemm::CWarpDstr;
// constexpr auto c_warp_y_lengths =
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensors(mIter)(nIter).get_thread_buffer());
});
});
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// c_block_tile.set_y_sliced_thread_data(
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
// c_warp_tensors(mIter)(nIter).get_thread_buffer());
// });
// });
return c_block_tile;
}

View File

@@ -4,9 +4,11 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include <type_traits>
namespace ck_tile {
// Default policy for MXGemmPipelineAgBgCrCompAsync
@@ -70,91 +72,234 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
return vector_size;
}
// DRAM tile distributions use STORAGE dimensions (for the storage tensor view)
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// // DRAM tile distributions use STORAGE dimensions (for the storage tensor view)
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
// {
// // using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
// // using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
using ALayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
// // constexpr index_t BlockSize = Problem::kBlockSize;
// // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// // constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
// // constexpr index_t K2 = 16; // 16 bytes
// // constexpr index_t K1 = 128 / K2; // 8
// // constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize
// // constexpr index_t M2 = get_warp_size() / K1; // 8
// // constexpr index_t M1 = BlockSize / get_warp_size(); // 4
// // constexpr index_t M0 = MPerBlock / (M2 * M1);
// // static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
// // static_assert(K0 * K1 * K2 * APackedSize == KPerBlock,
// // "K0, K1, K2 must cover whole KPerBlock!");
// // return make_static_tile_distribution(
// // tile_distribution_encoding< //
// // sequence<1>,
// // tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
// // tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
// // tuple<sequence<1>, sequence<2, 1>>,
// // sequence<1, 2, 2>, // M0,K0,K2
// // sequence<0, 0, 2>>{});
// constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
// /// NOTE: for flatmm style byte tensor, divide KPerBlock by APackedSize to get STORAGE dimensions
// // using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
// // using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
// // constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions
// /// NOTE: use original KPerBlock
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
// constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// using ALayout = remove_cvref_t<
// std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
MPerBlock,
KPerBlock, // Use storage dimensions
VecLoadSize,
getATileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
else
{
static_assert(false, "Not implemented");
// using TileEncodingPattern =
// tile_distribution_encoding_pattern_2d<BlockSize,
// KPerBlock,
// MPerBlock,
// VecLoadSize,
// getATileAccessPattern(),
// NumWaveGroups>;
// return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
// if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
// {
// using TileEncodingPattern =
// tile_distribution_encoding_pattern_2d<BlockSize,
// MPerBlock,
// KPerBlock, // Use storage dimensions
// VecLoadSize,
// getATileAccessPattern(),
// NumWaveGroups>;
// return TileEncodingPattern::make_2d_static_tile_distribution();
// }
// else
// {
// static_assert(false, "Not implemented");
// // using TileEncodingPattern =
// // tile_distribution_encoding_pattern_2d<BlockSize,
// // KPerBlock,
// // MPerBlock,
// // VecLoadSize,
// // getATileAccessPattern(),
// // NumWaveGroups>;
// // return TileEncodingPattern::make_2d_static_tile_distribution();
// }
// }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
// {
// /// NOTE: flatmm style dstr
// // using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
// // using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
// // constexpr index_t BlockSize = Problem::kBlockSize;
// // constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// // constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
using BLayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
// // constexpr index_t K2 = 16; // 16 bytes
// // constexpr index_t K1 = 128 / K2; // 8
// // constexpr index_t K0 = KPerBlock / (K1 * K2 * BPackedSize); // KPerBlock/256/packsize
// // constexpr index_t N2 = get_warp_size() / K1; // 8
// // constexpr index_t N1 = BlockSize / get_warp_size(); // 4
// // constexpr index_t N0 = NPerBlock / (N2 * N1);
// // static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
// // static_assert(K0 * K1 * K2 * BPackedSize == KPerBlock,
// // "K0, K1, K2 must cover whole KPerBlock!");
// // return make_static_tile_distribution(
// // tile_distribution_encoding< //
// // sequence<1>,
// // tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
// // tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
// // tuple<sequence<1>, sequence<2, 1>>,
// // sequence<1, 2, 2>, // N0,K0,K2
// // sequence<0, 0, 2>>{});
// constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
// /// NOTE: for flatmm style byte tensor, divide KPerBlock by BPackedSize to get STORAGE dimensions
// // using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
// // using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
// // constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions
// /// NOTE: use original KPerBlock
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
// constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
// constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// using BLayout = remove_cvref_t<
// std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
static_assert(false, "Not implemented");
}
else
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
NPerBlock,
KPerBlock, // Use storage dimensions
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
// if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
// {
// static_assert(false, "Not implemented");
// }
// else
// {
// using TileEncodingPattern =
// tile_distribution_encoding_pattern_2d<BlockSize,
// NPerBlock,
// KPerBlock, // Use storage dimensions
// VecLoadSize,
// getBTileAccessPattern(),
// NumWaveGroups>;
// return TileEncodingPattern::make_2d_static_tile_distribution();
// }
// }
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDSBytes_TileDistribution()
// {
// // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
// using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
// using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
// constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
// using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
// constexpr index_t MWarps = BlockWarps::at(number<0>{});
// constexpr index_t NWarps = BlockWarps::at(number<1>{});
// constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
// // constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
// constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
// constexpr index_t K_Lane = get_warp_size() / 16; // 4
// constexpr index_t K_Thread = KPerXdl / K_Lane; // 32
// constexpr index_t DWORDx4 = 16;
// constexpr index_t AK1 = DWORDx4 * APackedSize;
// if constexpr(K_Thread == AK1)
// return make_static_tile_distribution(
// tile_distribution_encoding< //
// sequence<NWarps>,
// tuple<sequence<MWarps, 1, MPerXdl>, sequence<K_Lane, AK1 / APackedSize>>,
// tuple<sequence<1, 0>, sequence<2, 1>>,
// tuple<sequence<0, 0>, sequence<0, 2>>,
// sequence<2>,
// sequence<1>>{});
// else
// return make_static_tile_distribution(
// tile_distribution_encoding< //
// sequence<NWarps>,
// tuple<sequence<MWarps, 1, MPerXdl>,
// sequence<K_Thread / AK1, K_Lane, AK1 / APackedSize>>,
// tuple<sequence<1, 0>, sequence<2, 1>>,
// tuple<sequence<0, 0>, sequence<1, 2>>,
// sequence<2, 2>,
// sequence<0, 2>>{});
// }
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLDSBytes_TileDistribution()
// {
// // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
// using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
// using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
// constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
// using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
// constexpr index_t MWarps = BlockWarps::at(number<0>{});
// constexpr index_t NWarps = BlockWarps::at(number<1>{});
// // constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
// constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
// constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
// constexpr index_t K_Lane = get_warp_size() / 16; // 4
// constexpr index_t K_Thread = KPerXdl / K_Lane; // 32
// constexpr index_t DWORDx4 = 16;
// constexpr index_t BK1 = DWORDx4 * BPackedSize;
// if constexpr(K_Thread == BK1)
// return make_static_tile_distribution(
// tile_distribution_encoding< //
// sequence<MWarps>,
// tuple<sequence<NWarps, 1, NPerXdl>, sequence<K_Lane, BK1 / BPackedSize>>,
// tuple<sequence<1, 0>, sequence<2, 1>>,
// tuple<sequence<0, 0>, sequence<0, 2>>,
// sequence<2>,
// sequence<1>>{});
// else
// return make_static_tile_distribution(
// tile_distribution_encoding< //
// sequence<MWarps>,
// tuple<sequence<NWarps, 1, NPerXdl>,
// sequence<K_Thread / BK1, K_Lane, BK1 / BPackedSize>>,
// tuple<sequence<1, 0>, sequence<2, 1>>,
// tuple<sequence<0, 0>, sequence<1, 2>>,
// sequence<2, 2>,
// sequence<0, 2>>{});
// }
template <typename Problem,
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
{
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions
/// NOTE: for flatmm style byte tensor, divide KPerBlock by APackedSize to get STORAGE dimensions
// using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
// using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
// constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions
/// NOTE: use original KPerBlock
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(is_a_load_tr<Problem>)
{
// TODO: better LDS descriptor for performance
@@ -170,6 +315,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
else
{
constexpr index_t KPack = GetSmemPackA<Problem>();
static_assert(KPack >= 16, "KPack must be at least 16");
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
@@ -190,12 +336,14 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize;
/// NOTE: for flatmm style byte tensor, divide KPerBlock by BPackedSize to get STORAGE dimensions
// using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
// using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
// constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions
/// NOTE: use original KPerBlock
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(is_b_load_tr<Problem>)
{
// TODO: better LDS descriptor for performance
@@ -211,6 +359,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
else
{
constexpr index_t KPack = GetSmemPackB<Problem>();
static_assert(KPack >= 16, "KPack must be at least 16");
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),