mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
revert mostly back to original comp_async
This commit is contained in:
@@ -298,38 +298,96 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
|
||||
auto a_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
/// NOTE: flatmm style byte tensor approach:
|
||||
// Create tile window with STORAGE dimensions to match LDS
|
||||
// auto&& tensor_view_tmp = a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view();
|
||||
// auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
|
||||
// const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
|
||||
// auto&& a_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
// static_cast<const uint8_t*>(byte_ptr),
|
||||
// make_tuple(rows, cols / APackedSize),
|
||||
// make_tuple(cols / APackedSize, 1),
|
||||
// number<16>{},
|
||||
// number<1>{});
|
||||
// return make_tile_window(a_tensor_view,
|
||||
// make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
|
||||
// [&]() {
|
||||
// auto origin = a_dram_block_window_tmp[number<idx>{}].get_window_origin();
|
||||
// if constexpr(is_a_col_major) {
|
||||
// origin[0] = origin[0] / APackedSize; // Adjust K origin
|
||||
// } else {
|
||||
// origin[1] = origin[1] / APackedSize; // Adjust K origin
|
||||
// }
|
||||
// return origin;
|
||||
// }(),
|
||||
// Policy::template MakeADramTileDistribution<Problem>());
|
||||
/// NOTE: re-use original tensor view but with adjusted origin and K/PackedSize
|
||||
// return make_tile_window(
|
||||
// a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
// make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
|
||||
// [&]() {
|
||||
// auto origin = a_dram_block_window_tmp[number<idx>{}].get_window_origin();
|
||||
// if constexpr(is_a_col_major) {
|
||||
// origin[0] = origin[0] / APackedSize; // Adjust K origin
|
||||
// } else {
|
||||
// origin[1] = origin[1] / APackedSize; // Adjust K origin
|
||||
// }
|
||||
// return origin;
|
||||
// }(),
|
||||
// Policy::template MakeADramTileDistribution<Problem>());
|
||||
/// NOTE: use original shapes
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{}),
|
||||
[&]() {
|
||||
auto origin = a_dram_block_window_tmp[number<idx>{}].get_window_origin();
|
||||
if constexpr(is_a_col_major) {
|
||||
origin[0] = origin[0] / APackedSize; // Adjust K origin
|
||||
} else {
|
||||
origin[1] = origin[1] / APackedSize; // Adjust K origin
|
||||
}
|
||||
return origin;
|
||||
}(),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
// B DRAM window(s) for load
|
||||
auto b_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
/// NOTE: flatmm style byte tensor approach:
|
||||
// Create tile window with STORAGE dimensions to match LDS
|
||||
// auto&& tensor_view_tmp = b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view();
|
||||
// auto&& byte_ptr = reinterpret_cast<const uint8_t*>(&(tensor_view_tmp.get_buffer_view()(0)));
|
||||
// const auto [rows, cols] = tensor_view_tmp.get_tensor_descriptor().get_lengths();
|
||||
// auto&& b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
// static_cast<const uint8_t*>(byte_ptr),
|
||||
// make_tuple(rows, cols / BPackedSize),
|
||||
// make_tuple(cols / BPackedSize, 1),
|
||||
// number<16>{},
|
||||
// number<1>{});
|
||||
// return make_tile_window(b_tensor_view,
|
||||
// make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{}),
|
||||
// [&]() {
|
||||
// auto origin = b_dram_block_window_tmp[number<idx>{}].get_window_origin();
|
||||
// if constexpr(is_b_row_major) {
|
||||
// origin[0] = origin[0] / BPackedSize; // Adjust K origin
|
||||
// } else {
|
||||
// origin[1] = origin[1] / BPackedSize; // Adjust K origin
|
||||
// }
|
||||
// return origin;
|
||||
// }(),
|
||||
// Policy::template MakeBDramTileDistribution<Problem>());
|
||||
/// NOTE: re-use original tensor view but with adjusted origin and K/PackedSize
|
||||
// return make_tile_window(
|
||||
// b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
// make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{}),
|
||||
// [&]() {
|
||||
// auto origin = b_dram_block_window_tmp[number<idx>{}].get_window_origin();
|
||||
// if constexpr(is_b_row_major) {
|
||||
// origin[0] = origin[0] / BPackedSize; // Adjust K origin
|
||||
// } else {
|
||||
// origin[1] = origin[1] / BPackedSize; // Adjust K origin
|
||||
// }
|
||||
// return origin;
|
||||
// }(),
|
||||
// Policy::template MakeBDramTileDistribution<Problem>());
|
||||
/// NOTE: use original shapes
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{}),
|
||||
[&]() {
|
||||
auto origin = b_dram_block_window_tmp[number<idx>{}].get_window_origin();
|
||||
if constexpr(is_b_row_major) {
|
||||
origin[0] = origin[0] / BPackedSize; // Adjust K origin
|
||||
} else {
|
||||
origin[1] = origin[1] / BPackedSize; // Adjust K origin
|
||||
}
|
||||
return origin;
|
||||
}(),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<BsLayout::size()>{});
|
||||
@@ -382,22 +440,41 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
|
||||
// this pipeline has a pair of LDS buffers per logical tile
|
||||
// TODO: check for packed size - are these blocks too big?
|
||||
/// NOTE: flatmm style byte tensor approach:
|
||||
// auto&& [a_lds_block0, b_lds_block0] = Base::template GetABLdsTensorViews<uint8_t, uint8_t>(p_smem_0);
|
||||
// auto&& [a_lds_block1, b_lds_block1] = Base::template GetABLdsTensorViews<uint8_t, uint8_t>(p_smem_1);
|
||||
/// NOTE: with original fp4 types:
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
|
||||
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
|
||||
|
||||
// set up LDS tile shapes - always use STORAGE dimensions for K
|
||||
/// NOTE: flatmm style byte tensor approach:
|
||||
// constexpr auto a_lds_shape = []() {
|
||||
// if constexpr(is_a_load_tr_v)
|
||||
// return make_tuple(number<KPerBlock / APackedSize>{}, number<MPerBlock>{});
|
||||
// else
|
||||
// return make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{});
|
||||
// }();
|
||||
|
||||
// constexpr auto b_lds_shape = []() {
|
||||
// if constexpr(is_b_load_tr_v)
|
||||
// return make_tuple(number<KPerBlock / BPackedSize>{}, number<NPerBlock>{});
|
||||
// else
|
||||
// return make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{});
|
||||
// }();
|
||||
/// NOTE: use original shapes
|
||||
constexpr auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr_v)
|
||||
return make_tuple(number<KPerBlock / APackedSize>{}, number<MPerBlock>{});
|
||||
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<MPerBlock>{}, number<KPerBlock / APackedSize>{});
|
||||
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
|
||||
constexpr auto b_lds_shape = []() {
|
||||
if constexpr(is_b_load_tr_v)
|
||||
return make_tuple(number<KPerBlock / BPackedSize>{}, number<NPerBlock>{});
|
||||
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<NPerBlock>{}, number<KPerBlock / BPackedSize>{});
|
||||
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
|
||||
// LDS tile windows for storing, one per LDS buffer
|
||||
@@ -413,10 +490,16 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
/// NOTE: flatmm style way to calculate steps with packed size
|
||||
// constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
// is_a_col_major ? make_array(KPerBlock / APackedSize, 0) : make_array(0, KPerBlock / APackedSize);
|
||||
// constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
// is_b_row_major ? make_array(KPerBlock / BPackedSize, 0) : make_array(0, KPerBlock / BPackedSize);
|
||||
/// NOTE: use original steps and assume that PackedSize is correctly applied elsewhere
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock / APackedSize, 0) : make_array(0, KPerBlock / APackedSize);
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock / BPackedSize, 0) : make_array(0, KPerBlock / BPackedSize);
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
|
||||
// read A(0), B(0) from DRAM to LDS window(0)
|
||||
// and advance the DRAM windows
|
||||
@@ -426,8 +509,13 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
// Initialize WarpGemm for MX scaling
|
||||
using WarpGemm = typename remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>::WarpGemm;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
// using WarpGemm = typename remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>::WarpGemm;
|
||||
// using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
// Initialize block gemm and C block tile
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
clear_tile(c_block_tile);
|
||||
|
||||
// read A(1), B(1) from DRAM to LDS window(1)
|
||||
// and advance the DRAM windows
|
||||
@@ -449,6 +537,7 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
ALdsTile a_block_tile0, a_block_tile1;
|
||||
BLdsTile b_block_tile0, b_block_tile1;
|
||||
|
||||
// Some sanity checks on the LDS tile sizes
|
||||
static_assert(sizeof(ALdsTile) == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize) * NWarp / BlockSize, "ALdsTile size is wrong!");
|
||||
static_assert(sizeof(BLdsTile) == NPerBlock * (KPerBlock * sizeof(BDataType) / BPackedSize) * MWarp / BlockSize, "BLdsTile size is wrong!");
|
||||
static_assert(Policy::template GetSmemSizeA<Problem>() == MPerBlock * (KPerBlock * sizeof(ADataType) / APackedSize), "SmemSizeA size is wrong!");
|
||||
@@ -496,36 +585,44 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
move_tile_window(scale_b_dram_window, {KPerBlock / ScaleBlockSize / KXdlPack, 0});
|
||||
};
|
||||
|
||||
constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
|
||||
if constexpr(is_a_load_tr_v)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename decltype(ALdsTileDistr)::DstrEncode,
|
||||
typename Problem::ADataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return ALdsTileDistr;
|
||||
}();
|
||||
constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
|
||||
if constexpr(is_b_load_tr_v)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename decltype(BLdsTileDistr)::DstrEncode,
|
||||
typename Problem::BDataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return BLdsTileDistr;
|
||||
}();
|
||||
// constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
|
||||
// if constexpr(is_a_load_tr_v)
|
||||
// return make_static_tile_distribution(
|
||||
// typename InputTileDistributionTraits<
|
||||
// typename decltype(ALdsTileDistr)::DstrEncode,
|
||||
// typename Problem::ADataType>::TransposedDstrEncode{});
|
||||
// else
|
||||
// return ALdsTileDistr;
|
||||
// }();
|
||||
// constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
|
||||
// if constexpr(is_b_load_tr_v)
|
||||
// return make_static_tile_distribution(
|
||||
// typename InputTileDistributionTraits<
|
||||
// typename decltype(BLdsTileDistr)::DstrEncode,
|
||||
// typename Problem::BDataType>::TransposedDstrEncode{});
|
||||
// else
|
||||
// return BLdsTileDistr;
|
||||
// }();
|
||||
|
||||
// LDS tile windows for reading;
|
||||
// they share the data pointer with the LDS windows for storing
|
||||
// but also associate with a distribution to produce a register tile when reading
|
||||
auto a_lds_ld_window0 =
|
||||
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
|
||||
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, ALdsTileDistr);
|
||||
auto a_lds_ld_window1 =
|
||||
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
|
||||
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, ALdsTileDistr);
|
||||
auto b_lds_ld_window0 =
|
||||
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
|
||||
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, BLdsTileDistr);
|
||||
auto b_lds_ld_window1 =
|
||||
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
|
||||
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, BLdsTileDistr);
|
||||
// auto a_lds_ld_window0 =
|
||||
// make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, Policy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
|
||||
// auto a_lds_ld_window1 =
|
||||
// make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, Policy::template MakeMX_ALDSBytes_TileDistribution<Problem>());
|
||||
// auto b_lds_ld_window0 =
|
||||
// make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, Policy::template MakeMX_BLDSBytes_TileDistribution<Problem>());
|
||||
// auto b_lds_ld_window1 =
|
||||
// make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, Policy::template MakeMX_BLDSBytes_TileDistribution<Problem>());
|
||||
|
||||
static_assert(!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>) &&
|
||||
!(is_tile_window_linear_v<decltype(a_lds_ld_window1)>) &&
|
||||
@@ -534,61 +631,62 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
"LDS windows must not be linear");
|
||||
|
||||
// Create warp-level C tensors (one per M/N iteration)
|
||||
statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp> c_warp_tensors;
|
||||
// statically_indexed_array<statically_indexed_array<CWarpTensor, NIterPerWarp>, MIterPerWarp> c_warp_tensors;
|
||||
|
||||
// Initialize C tensors
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
clear_tile(c_warp_tensors(mIter)(nIter));
|
||||
});
|
||||
});
|
||||
/// TODO: create CBlockTile with block_gemm.MakeCBlockTile()
|
||||
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// clear_tile(c_warp_tensors(mIter)(nIter));
|
||||
// });
|
||||
// });
|
||||
|
||||
// Warp GEMM loop with MX scaling
|
||||
auto warp_gemm_loop = [&](const auto& a_block_tile, const auto& b_block_tile, const auto& scale_a, const auto& scale_b) {
|
||||
// Extract A/B values from block tiles to warp iteration structure
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(typename WarpGemm::AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<WarpGemm::AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(typename WarpGemm::BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<WarpGemm::BWarpDstr::NDimY, 0>{};
|
||||
// auto warp_gemm_loop = [&](const auto& a_block_tile, const auto& b_block_tile, const auto& scale_a, const auto& scale_b) {
|
||||
// // Extract A/B values from block tiles to warp iteration structure
|
||||
// constexpr auto a_warp_y_lengths =
|
||||
// to_sequence(typename WarpGemm::AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
// constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<WarpGemm::AWarpDstr::NDimY, 0>{};
|
||||
// constexpr auto b_warp_y_lengths =
|
||||
// to_sequence(typename WarpGemm::BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
// constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<WarpGemm::BWarpDstr::NDimY, 0>{};
|
||||
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto k_iter) {
|
||||
// Map k_iter to packed scale index and OpSel
|
||||
constexpr index_t kScalePacked = (k_iter * KPerXdl) / (ScaleBlockSize * KXdlPack);
|
||||
// constexpr index_t kScaleInPack = ((k_iter * KPerXdl) / ScaleBlockSize) % KXdlPack;
|
||||
constexpr index_t kScaleInPack = k_iter;
|
||||
// static_for<0, KIterPerWarp, 1>{}([&](auto k_iter) {
|
||||
// // Map k_iter to packed scale index and OpSel
|
||||
// constexpr index_t kScalePacked = (k_iter * KPerXdl) / (ScaleBlockSize * KXdlPack);
|
||||
// // constexpr index_t kScaleInPack = ((k_iter * KPerXdl) / ScaleBlockSize) % KXdlPack;
|
||||
// constexpr index_t kScaleInPack = k_iter;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) {
|
||||
constexpr auto OpSelA = kScaleInPack;
|
||||
// static_for<0, MIterPerWarp, 1>{}([&](auto m_iter) {
|
||||
// constexpr auto OpSelA = kScaleInPack;
|
||||
|
||||
// read A warp tensor from A block tensor
|
||||
typename WarpGemm::AWarpTensor a_warp_tensor;
|
||||
// // read A warp tensor from A block tensor
|
||||
// typename WarpGemm::AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<m_iter, k_iter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
// a_warp_tensor.get_thread_buffer() = a_block_tile.get_y_sliced_thread_data(
|
||||
// merge_sequences(sequence<m_iter, k_iter>{}, a_warp_y_index_zeros),
|
||||
// merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
|
||||
constexpr auto OpSelB = kScaleInPack;
|
||||
// static_for<0, NIterPerWarp, 1>{}([&](auto n_iter) {
|
||||
// constexpr auto OpSelB = kScaleInPack;
|
||||
|
||||
// read B warp tensor from B block tensor
|
||||
typename WarpGemm::BWarpTensor b_warp_tensor;
|
||||
// // read B warp tensor from B block tensor
|
||||
// typename WarpGemm::BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<n_iter, k_iter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
// b_warp_tensor.get_thread_buffer() = b_block_tile.get_y_sliced_thread_data(
|
||||
// merge_sequences(sequence<n_iter, k_iter>{}, b_warp_y_index_zeros),
|
||||
// merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
WarpGemm{}.template operator()<OpSelA, OpSelB>(
|
||||
c_warp_tensors(m_iter)(n_iter),
|
||||
a_warp_tensor,
|
||||
b_warp_tensor,
|
||||
scale_a(m_iter)(number<kScalePacked>{}).get_thread_buffer()[0],
|
||||
scale_b(n_iter)(number<kScalePacked>{}).get_thread_buffer()[0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
// WarpGemm{}.template operator()<OpSelA, OpSelB>(
|
||||
// c_warp_tensors(m_iter)(n_iter),
|
||||
// a_warp_tensor,
|
||||
// b_warp_tensor,
|
||||
// scale_a(m_iter)(number<kScalePacked>{}).get_thread_buffer()[0],
|
||||
// scale_b(n_iter)(number<kScalePacked>{}).get_thread_buffer()[0]);
|
||||
// });
|
||||
// });
|
||||
// });
|
||||
// };
|
||||
|
||||
// write to LDS window(0) must complete before the local prefetch
|
||||
block_sync_lds_direct_load();
|
||||
@@ -636,12 +734,16 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
b_tile_windows[number<0>{}],
|
||||
b_dram_tile_window_step);
|
||||
// C(i-3) = A(i-3) @ B(i-3) with MX scaling
|
||||
warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_ping;
|
||||
ignore = scale_b_tile_ping;
|
||||
HotLoopScheduler();
|
||||
// Load scales for iteration i+2 (ping)
|
||||
if (i_global_read + 2 < num_loop) {
|
||||
load_scales_(scale_a_tile_ping, scale_b_tile_ping);
|
||||
}
|
||||
HotLoopScheduler();
|
||||
}
|
||||
// pong
|
||||
{
|
||||
@@ -661,13 +763,17 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
b_tile_windows[number<0>{}],
|
||||
b_dram_tile_window_step);
|
||||
// C(i-2) = A(i-2) @ B(i-2) with MX scaling
|
||||
warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_pong;
|
||||
ignore = scale_b_tile_pong;
|
||||
HotLoopScheduler();
|
||||
// Load scales for iteration i+2 (pong)
|
||||
/// TODO: check condition
|
||||
if (i_global_read + 2 < num_loop) {
|
||||
load_scales_(scale_a_tile_pong, scale_b_tile_pong);
|
||||
}
|
||||
HotLoopScheduler();
|
||||
}
|
||||
i_global_read += 2;
|
||||
} while(i_global_read < num_loop);
|
||||
@@ -681,7 +787,11 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
// C(num_loop-2) = A(num_loop-2) @ B(num_loop-2) with MX scaling
|
||||
warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_ping;
|
||||
ignore = scale_b_tile_ping;
|
||||
/// TODO: load next scales to ping for the last iteration
|
||||
}
|
||||
{
|
||||
@@ -691,11 +801,19 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling
|
||||
warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_pong;
|
||||
ignore = scale_b_tile_pong;
|
||||
}
|
||||
{
|
||||
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
|
||||
warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_ping;
|
||||
ignore = scale_b_tile_ping;
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::Two)
|
||||
@@ -706,36 +824,48 @@ struct MXGemmPipelineAgBgCrCompAsync : public BaseMXGemmPipelineAgBgCrCompAsync<
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1) with MX scaling
|
||||
warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_ping;
|
||||
ignore = scale_b_tile_ping;
|
||||
}
|
||||
{
|
||||
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
|
||||
warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
// warp_gemm_loop(a_block_tile1, b_block_tile1, scale_a_tile_pong, scale_b_tile_pong);
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_pong;
|
||||
ignore = scale_b_tile_pong;
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
// C(num_loop) = A(num_loop) @ B(num_loop) with MX scaling
|
||||
warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
// warp_gemm_loop(a_block_tile0, b_block_tile0, scale_a_tile_ping, scale_b_tile_ping);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
/// TODO: remove these after creating a block gemm with scales
|
||||
ignore = scale_a_tile_ping;
|
||||
ignore = scale_b_tile_ping;
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// Convert warp-level C tensors to block tile format
|
||||
auto c_block_tile = BlockGemm{}.MakeCBlockTile();
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
// auto c_block_tile = BlockGemm{}.MakeCBlockTile();
|
||||
// using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
// constexpr auto c_warp_y_lengths =
|
||||
// to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
// constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensors(mIter)(nIter).get_thread_buffer());
|
||||
});
|
||||
});
|
||||
// static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// c_block_tile.set_y_sliced_thread_data(
|
||||
// merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
// merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
// c_warp_tensors(mIter)(nIter).get_thread_buffer());
|
||||
// });
|
||||
// });
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
@@ -4,9 +4,11 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
// Default policy for MXGemmPipelineAgBgCrCompAsync
|
||||
@@ -70,91 +72,234 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
return vector_size;
|
||||
}
|
||||
|
||||
// DRAM tile distributions use STORAGE dimensions (for the storage tensor view)
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
// // DRAM tile distributions use STORAGE dimensions (for the storage tensor view)
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
// {
|
||||
// // using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
// // using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
|
||||
using ALayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
|
||||
// // constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// // constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// // constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
// // constexpr index_t K2 = 16; // 16 bytes
|
||||
// // constexpr index_t K1 = 128 / K2; // 8
|
||||
// // constexpr index_t K0 = KPerBlock / (K1 * K2 * APackedSize); // KPerBlock/256/packsize
|
||||
|
||||
// // constexpr index_t M2 = get_warp_size() / K1; // 8
|
||||
// // constexpr index_t M1 = BlockSize / get_warp_size(); // 4
|
||||
// // constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
|
||||
// // static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
|
||||
// // static_assert(K0 * K1 * K2 * APackedSize == KPerBlock,
|
||||
// // "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
// // return make_static_tile_distribution(
|
||||
// // tile_distribution_encoding< //
|
||||
// // sequence<1>,
|
||||
// // tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
|
||||
// // tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
|
||||
// // tuple<sequence<1>, sequence<2, 1>>,
|
||||
// // sequence<1, 2, 2>, // M0,K0,K2
|
||||
// // sequence<0, 0, 2>>{});
|
||||
// constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
// /// NOTE: for flatmm style byte tensor, divide KPerBlock by APackedSize to get STORAGE dimensions
|
||||
// // using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
// // using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
// // constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions
|
||||
// /// NOTE: use original KPerBlock
|
||||
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
// constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// using ALayout = remove_cvref_t<
|
||||
// std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
|
||||
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock, // Use storage dimensions
|
||||
VecLoadSize,
|
||||
getATileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Not implemented");
|
||||
// using TileEncodingPattern =
|
||||
// tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
// KPerBlock,
|
||||
// MPerBlock,
|
||||
// VecLoadSize,
|
||||
// getATileAccessPattern(),
|
||||
// NumWaveGroups>;
|
||||
// return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
}
|
||||
// if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
// {
|
||||
// using TileEncodingPattern =
|
||||
// tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
// MPerBlock,
|
||||
// KPerBlock, // Use storage dimensions
|
||||
// VecLoadSize,
|
||||
// getATileAccessPattern(),
|
||||
// NumWaveGroups>;
|
||||
// return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// static_assert(false, "Not implemented");
|
||||
// // using TileEncodingPattern =
|
||||
// // tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
// // KPerBlock,
|
||||
// // MPerBlock,
|
||||
// // VecLoadSize,
|
||||
// // getATileAccessPattern(),
|
||||
// // NumWaveGroups>;
|
||||
// // return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
// }
|
||||
// }
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
// {
|
||||
// /// NOTE: flatmm style dstr
|
||||
// // using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
// // using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
|
||||
// // constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// // constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// // constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
// // constexpr index_t K2 = 16; // 16 bytes
|
||||
// // constexpr index_t K1 = 128 / K2; // 8
|
||||
// // constexpr index_t K0 = KPerBlock / (K1 * K2 * BPackedSize); // KPerBlock/256/packsize
|
||||
|
||||
// // constexpr index_t N2 = get_warp_size() / K1; // 8
|
||||
// // constexpr index_t N1 = BlockSize / get_warp_size(); // 4
|
||||
// // constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
|
||||
// // static_assert(N0 * N1 * N2 == NPerBlock, "N0, N1, N2 must cover whole NPerBlock!");
|
||||
// // static_assert(K0 * K1 * K2 * BPackedSize == KPerBlock,
|
||||
// // "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
// // return make_static_tile_distribution(
|
||||
// // tile_distribution_encoding< //
|
||||
// // sequence<1>,
|
||||
// // tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
|
||||
// // tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
|
||||
// // tuple<sequence<1>, sequence<2, 1>>,
|
||||
// // sequence<1, 2, 2>, // N0,K0,K2
|
||||
// // sequence<0, 0, 2>>{});
|
||||
// constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
// constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
// /// NOTE: for flatmm style byte tensor, divide KPerBlock by BPackedSize to get STORAGE dimensions
|
||||
// // using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
// // using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
// // constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
// // constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions
|
||||
// /// NOTE: use original KPerBlock
|
||||
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
// constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// using BLayout = remove_cvref_t<
|
||||
// std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
static_assert(false, "Not implemented");
|
||||
}
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlock, // Use storage dimensions
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
}
|
||||
// if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
// {
|
||||
// static_assert(false, "Not implemented");
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// using TileEncodingPattern =
|
||||
// tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
// NPerBlock,
|
||||
// KPerBlock, // Use storage dimensions
|
||||
// VecLoadSize,
|
||||
// getBTileAccessPattern(),
|
||||
// NumWaveGroups>;
|
||||
// return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
// }
|
||||
// }
|
||||
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDSBytes_TileDistribution()
|
||||
// {
|
||||
// // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
// using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
// using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
// constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
// using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
// constexpr index_t MWarps = BlockWarps::at(number<0>{});
|
||||
// constexpr index_t NWarps = BlockWarps::at(number<1>{});
|
||||
// constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
// // constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
// constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
|
||||
// constexpr index_t K_Lane = get_warp_size() / 16; // 4
|
||||
// constexpr index_t K_Thread = KPerXdl / K_Lane; // 32
|
||||
// constexpr index_t DWORDx4 = 16;
|
||||
// constexpr index_t AK1 = DWORDx4 * APackedSize;
|
||||
|
||||
// if constexpr(K_Thread == AK1)
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding< //
|
||||
// sequence<NWarps>,
|
||||
// tuple<sequence<MWarps, 1, MPerXdl>, sequence<K_Lane, AK1 / APackedSize>>,
|
||||
// tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
// tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
// sequence<2>,
|
||||
// sequence<1>>{});
|
||||
// else
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding< //
|
||||
// sequence<NWarps>,
|
||||
// tuple<sequence<MWarps, 1, MPerXdl>,
|
||||
// sequence<K_Thread / AK1, K_Lane, AK1 / APackedSize>>,
|
||||
// tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
// tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
// sequence<2, 2>,
|
||||
// sequence<0, 2>>{});
|
||||
// }
|
||||
|
||||
// template <typename Problem>
|
||||
// CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BLDSBytes_TileDistribution()
|
||||
// {
|
||||
// // static_assert(BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
// using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
// using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
// constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
// using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
// constexpr index_t MWarps = BlockWarps::at(number<0>{});
|
||||
// constexpr index_t NWarps = BlockWarps::at(number<1>{});
|
||||
// // constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
// constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
// constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
|
||||
// constexpr index_t K_Lane = get_warp_size() / 16; // 4
|
||||
// constexpr index_t K_Thread = KPerXdl / K_Lane; // 32
|
||||
// constexpr index_t DWORDx4 = 16;
|
||||
// constexpr index_t BK1 = DWORDx4 * BPackedSize;
|
||||
|
||||
// if constexpr(K_Thread == BK1)
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding< //
|
||||
// sequence<MWarps>,
|
||||
// tuple<sequence<NWarps, 1, NPerXdl>, sequence<K_Lane, BK1 / BPackedSize>>,
|
||||
// tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
// tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
// sequence<2>,
|
||||
// sequence<1>>{});
|
||||
// else
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding< //
|
||||
// sequence<MWarps>,
|
||||
// tuple<sequence<NWarps, 1, NPerXdl>,
|
||||
// sequence<K_Thread / BK1, K_Lane, BK1 / BPackedSize>>,
|
||||
// tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
// tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
// sequence<2, 2>,
|
||||
// sequence<0, 2>>{});
|
||||
// }
|
||||
|
||||
template <typename Problem,
|
||||
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
|
||||
{
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions
|
||||
/// NOTE: for flatmm style byte tensor, divide KPerBlock by APackedSize to get STORAGE dimensions
|
||||
// using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
// using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
// constexpr index_t APackedSize = numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / APackedSize; // Use STORAGE dimensions
|
||||
/// NOTE: use original KPerBlock
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
if constexpr(is_a_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better LDS descriptor for performance
|
||||
@@ -170,6 +315,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
else
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
static_assert(KPack >= 16, "KPack must be at least 16");
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
@@ -190,12 +336,14 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize;
|
||||
/// NOTE: for flatmm style byte tensor, divide KPerBlock by BPackedSize to get STORAGE dimensions
|
||||
// using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
// using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
// constexpr index_t BPackedSize = numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
// constexpr index_t KPerBlock = Problem::BlockGemmShape::kK / BPackedSize; // Use STORAGE dimensions
|
||||
/// NOTE: use original KPerBlock
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
if constexpr(is_b_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better LDS descriptor for performance
|
||||
@@ -211,6 +359,7 @@ struct MXGemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
else
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
static_assert(KPack >= 16, "KPack must be at least 16");
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
|
||||
|
||||
Reference in New Issue
Block a user