refine code

This commit is contained in:
solin
2025-04-17 06:10:19 +00:00
parent 7048306c28
commit eb2c93186a

View File

@@ -71,10 +71,6 @@ struct BlockFlatmmASmemBSmemCRegV1
const ABlockWindow& a_block_window,
BFlatBlockTensor& b_warp_tensor) const
{
// static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
// std::is_same_v<BDataType, typename BFlatBlockWindow::DataType> &&
// std::is_same_v<CDataType, typename CBlockTensor::DataType>,
// "wrong!");
constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
@@ -94,9 +90,6 @@ struct BlockFlatmmASmemBSmemCRegV1
constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp;
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
// constexpr index_t NFlatPerBlockPerIter = BlockGemmShape::flatNPerWarp;
// constexpr index_t KFlatPerBlockPerIter = BlockGemmShape::flatKPerWarp;
const index_t iMWarp = get_warp_id() / NWarp;
// construct A-warp-window
@@ -118,24 +111,6 @@ struct BlockFlatmmASmemBSmemCRegV1
});
});
// construct Bflat-warp-window
// auto b_flat_warp_windows_tmp = b_flat_block_window;
// statically_indexed_array<
// statically_indexed_array<decltype(b_flat_warp_windows_tmp), KIterPerWarp>,
// NIterPerWarp>
// b_flat_warp_windows;
// static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// b_flat_warp_windows(nIter)(kIter) = b_flat_warp_windows_tmp;
// move_tile_window(b_flat_warp_windows(nIter)(kIter),
// {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
// });
// });
// auto b_warp_windows = b_origin_warp_windows;
// auto b_warp_windows = b_flat_warp_windows;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
@@ -150,9 +125,6 @@ struct BlockFlatmmASmemBSmemCRegV1
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
// const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
@@ -172,16 +144,6 @@ struct BlockFlatmmASmemBSmemCRegV1
});
});
}
// // C = A * B
// template <typename ABlockTensorTmp, typename BFlatBlockWindow>
// CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
// const BFlatBlockWindow& b_flat_block_window) const
// {
// auto c_block_tensor = MakeCBlockTile();
// operator()(c_block_tensor, a_block_tensor_tmp, b_flat_block_window);
// return c_block_tensor;
// }
};
} // namespace ck_tile