mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Add Flatmm MX FP8 (#3208)
* Use async for flatmm mxfp4
* Fix preshuffle
* Add flatmm mxfp8
* Thanks, Copilot
* Thanks Copilot again~
This commit is contained in:
@@ -41,6 +41,8 @@ using long_number = constant<v>;
|
||||
|
||||
template <bool b>
|
||||
using bool_constant = constant<b>;
|
||||
using true_type = bool_constant<true>;
|
||||
using false_type = bool_constant<false>;
|
||||
|
||||
#define CK_TILE_LEFT_UNARY_OP(OP) \
|
||||
template <auto x> \
|
||||
|
||||
@@ -21,9 +21,10 @@ namespace ck_tile {
|
||||
template <typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<std::is_class_v<TileWindow_>>>
|
||||
typename offset_t,
|
||||
typename = std::enable_if_t<std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE auto load_tile_with_offset(const TileWindow_& tile_window,
|
||||
index_t offset,
|
||||
offset_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
@@ -67,11 +68,12 @@ template <typename DistributedTensor_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename offset_t,
|
||||
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor_>> &&
|
||||
std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE auto load_tile_with_offset(DistributedTensor_& dst_tile,
|
||||
const TileWindow_& tile_window,
|
||||
index_t offset,
|
||||
offset_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
@@ -147,29 +149,31 @@ template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool static_move_ys = false,
|
||||
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>> &&
|
||||
std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE auto async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
|
||||
CK_TILE_DEVICE void async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
index_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> occ = {},
|
||||
bool_constant<static_move_ys> smy = {})
|
||||
{
|
||||
return tile_window.async_load_with_offset(
|
||||
offset, lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
tile_window.async_load_with_offset(offset, lds_tile, number<i_access>{}, occ, smy);
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
|
||||
bool oob_conditional_check = true,
|
||||
bool static_move_ys = false>
|
||||
CK_TILE_DEVICE void async_load_tile(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> occ = {},
|
||||
bool_constant<static_move_ys> smy = {})
|
||||
{
|
||||
return async_load_tile_with_offset(
|
||||
lds_tile, tile_window, 0, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
async_load_tile_with_offset(lds_tile, tile_window, 0, number<i_access>{}, occ, smy);
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
@@ -177,19 +181,19 @@ template <typename LdsTileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
|
||||
CK_TILE_DEVICE void async_load_fence(index_t cnt = 0)
|
||||
{
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
@@ -166,8 +166,8 @@ struct tensor_view
|
||||
{
|
||||
return buf_.template async_get<X>(
|
||||
smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coord.get_offset() / PackedSize + linear_offset / PackedSize,
|
||||
0, // linear_offset need to be imm and is not supported currently
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
@@ -156,8 +156,10 @@ struct tile_window_with_static_distribution
|
||||
0, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_with_offset(index_t offset,
|
||||
template <index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename offset_t = index_t>
|
||||
CK_TILE_DEVICE auto load_with_offset(offset_t offset,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
@@ -291,14 +293,16 @@ struct tile_window_with_static_distribution
|
||||
0, dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
template <typename DataType,
|
||||
typename StaticTileDistribution,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor>>>>
|
||||
CK_TILE_DEVICE auto load_with_offset(index_t offset,
|
||||
DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
typename offset_t>
|
||||
CK_TILE_DEVICE void load_with_offset( //
|
||||
offset_t offset,
|
||||
static_distributed_tensor<DataType, StaticTileDistribution>& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
@@ -306,6 +310,19 @@ struct tile_window_with_static_distribution
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
const index_t linear_off = [&]() {
|
||||
if constexpr(std::is_integral_v<offset_t>)
|
||||
return offset;
|
||||
else if constexpr(is_constant_v<offset_t>)
|
||||
return offset_t::value;
|
||||
else
|
||||
{
|
||||
auto bottom_tensor_idx_off = to_multi_index(offset_t{});
|
||||
auto bottom_tensor_coord_off = make_tensor_coordinate(
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
|
||||
return bottom_tensor_coord_off.get_offset();
|
||||
}
|
||||
}();
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structured bindings (to be captured later) if compiled in C++20
|
||||
@@ -321,7 +338,9 @@ struct tile_window_with_static_distribution
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, offset, bool_constant<oob_conditional_check>{});
|
||||
bottom_tensor_thread_coord,
|
||||
linear_off,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
@@ -514,11 +533,13 @@ struct tile_window_with_static_distribution
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool static_move_ys = false,
|
||||
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
|
||||
CK_TILE_DEVICE void async_load_with_offset(index_t offset,
|
||||
LdsTileWindow_&& lds_tile,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<static_move_ys> = {}) const
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
@@ -531,7 +552,7 @@ struct tile_window_with_static_distribution
|
||||
const auto window_origin = lds_tile.get_window_origin();
|
||||
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
|
||||
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
|
||||
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
|
||||
auto lds_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
@@ -543,22 +564,51 @@ struct tile_window_with_static_distribution
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// Use precomputed window origin
|
||||
constexpr auto idx_ys_offset = [&]() {
|
||||
constexpr auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess);
|
||||
constexpr auto adapter_ys_offset = make_tensor_adaptor_coordinate(
|
||||
StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(array<index_t, Base::NDimP>{0},
|
||||
to_array<index_t, idx_off_ys.size()>(idx_off_ys)));
|
||||
return adapter_ys_offset.get_bottom_index();
|
||||
}();
|
||||
const auto lds_ys_offset = [&]() {
|
||||
if constexpr(static_move_ys)
|
||||
{
|
||||
const auto coord_ys_offset =
|
||||
make_tensor_coordinate(tensor_descriptor, idx_ys_offset);
|
||||
return coord_ys_offset.get_offset();
|
||||
}
|
||||
else
|
||||
return 0;
|
||||
}();
|
||||
|
||||
// Use precomputed window origin & tensor descriptor
|
||||
auto lds_bottom_tensor_thread_idx =
|
||||
window_origin + window_adaptor_warp_coord.get_bottom_index();
|
||||
|
||||
// Use precomputed tensor descriptor
|
||||
const auto lds_coord =
|
||||
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
|
||||
|
||||
// Calculate SMEM address using base pointer
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
|
||||
lds_coord.get_offset() / Traits::PackedSize +
|
||||
lds_ys_offset / Traits::PackedSize;
|
||||
|
||||
const auto dram_ys_offset = [&]() {
|
||||
if constexpr(static_move_ys)
|
||||
{
|
||||
const auto coord_ys_offset = make_tensor_coordinate(
|
||||
this->get_bottom_tensor_view().get_tensor_descriptor(), idx_ys_offset);
|
||||
return coord_ys_offset.get_offset();
|
||||
}
|
||||
else
|
||||
return 0;
|
||||
}();
|
||||
|
||||
// Asynchronously read from the bottom tensor into LDS (destination: smem)
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
offset,
|
||||
offset + dram_ys_offset,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
// Move thread coordinate if not last access
|
||||
@@ -569,11 +619,15 @@ struct tile_window_with_static_distribution
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
if constexpr(!static_move_ys)
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord,
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_ps_ys);
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
|
||||
if constexpr(!static_move_ys)
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user