[CK_TILE] Add Flatmm MX FP8 (#3208)

* Use async for flatmm mxfp4

* Fix preshuffle

* Add flatmm mxfp8

* Thanks, Copilot

* Thanks Copilot again~
This commit is contained in:
Yi DING
2025-11-20 10:35:15 +08:00
committed by GitHub
parent 4e49e0228b
commit 47e2ed838e
17 changed files with 698 additions and 595 deletions

View File

@@ -41,6 +41,8 @@ using long_number = constant<v>;
template <bool b>
using bool_constant = constant<b>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
#define CK_TILE_LEFT_UNARY_OP(OP) \
template <auto x> \

View File

@@ -21,9 +21,10 @@ namespace ck_tile {
template <typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<TileWindow_>>>
typename offset_t,
typename = std::enable_if_t<std::is_class_v<TileWindow_>>>
CK_TILE_DEVICE auto load_tile_with_offset(const TileWindow_& tile_window,
index_t offset,
offset_t offset,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
@@ -67,11 +68,12 @@ template <typename DistributedTensor_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
typename offset_t,
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor_>> &&
std::is_class_v<TileWindow_>>>
CK_TILE_DEVICE auto load_tile_with_offset(DistributedTensor_& dst_tile,
const TileWindow_& tile_window,
index_t offset,
offset_t offset,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
@@ -147,29 +149,31 @@ template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool static_move_ys = false,
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>> &&
std::is_class_v<TileWindow_>>>
CK_TILE_DEVICE auto async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
CK_TILE_DEVICE void async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
index_t offset,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
number<i_access> = {},
bool_constant<oob_conditional_check> occ = {},
bool_constant<static_move_ys> smy = {})
{
return tile_window.async_load_with_offset(
offset, lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
tile_window.async_load_with_offset(offset, lds_tile, number<i_access>{}, occ, smy);
}
template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
bool oob_conditional_check = true,
bool static_move_ys = false>
CK_TILE_DEVICE void async_load_tile(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
number<i_access> = {},
bool_constant<oob_conditional_check> occ = {},
bool_constant<static_move_ys> smy = {})
{
return async_load_tile_with_offset(
lds_tile, tile_window, 0, number<i_access>{}, bool_constant<oob_conditional_check>{});
async_load_tile_with_offset(lds_tile, tile_window, 0, number<i_access>{}, occ, smy);
}
template <typename LdsTileWindow_,
@@ -177,19 +181,19 @@ template <typename LdsTileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
CK_TILE_DEVICE void async_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

View File

@@ -166,8 +166,8 @@ struct tensor_view
{
return buf_.template async_get<X>(
smem,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coord.get_offset() / PackedSize + linear_offset / PackedSize,
0, // linear_offset needs to be an immediate, which is not supported currently
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}

View File

@@ -156,8 +156,10 @@ struct tile_window_with_static_distribution
0, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_with_offset(index_t offset,
template <index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
typename offset_t = index_t>
CK_TILE_DEVICE auto load_with_offset(offset_t offset,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
@@ -291,14 +293,16 @@ struct tile_window_with_static_distribution
0, dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor,
template <typename DataType,
typename StaticTileDistribution,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor>>>>
CK_TILE_DEVICE auto load_with_offset(index_t offset,
DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
typename offset_t>
CK_TILE_DEVICE void load_with_offset( //
offset_t offset,
static_distributed_tensor<DataType, StaticTileDistribution>& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
@@ -306,6 +310,19 @@ struct tile_window_with_static_distribution
constexpr auto tile_dstr = typename Base::TileDstr{};
const index_t linear_off = [&]() {
if constexpr(std::is_integral_v<offset_t>)
return offset;
else if constexpr(is_constant_v<offset_t>)
return offset_t::value;
else
{
auto bottom_tensor_idx_off = to_multi_index(offset_t{});
auto bottom_tensor_coord_off = make_tensor_coordinate(
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
return bottom_tensor_coord_off.get_offset();
}
}();
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structured bindings (to be captured later) if compiled in C++20
@@ -321,7 +338,9 @@ struct tile_window_with_static_distribution
// read from bottom tensor
const vector_t vec_value =
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, offset, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord,
linear_off,
bool_constant<oob_conditional_check>{});
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
@@ -514,11 +533,13 @@ struct tile_window_with_static_distribution
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool static_move_ys = false,
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
CK_TILE_DEVICE void async_load_with_offset(index_t offset,
LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
bool_constant<oob_conditional_check> = {},
bool_constant<static_move_ys> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
@@ -531,7 +552,7 @@ struct tile_window_with_static_distribution
const auto window_origin = lds_tile.get_window_origin();
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
auto lds_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
@@ -543,22 +564,51 @@ struct tile_window_with_static_distribution
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// Use precomputed window origin
constexpr auto idx_ys_offset = [&]() {
constexpr auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess);
constexpr auto adapter_ys_offset = make_tensor_adaptor_coordinate(
StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(),
container_concat(array<index_t, Base::NDimP>{0},
to_array<index_t, idx_off_ys.size()>(idx_off_ys)));
return adapter_ys_offset.get_bottom_index();
}();
const auto lds_ys_offset = [&]() {
if constexpr(static_move_ys)
{
const auto coord_ys_offset =
make_tensor_coordinate(tensor_descriptor, idx_ys_offset);
return coord_ys_offset.get_offset();
}
else
return 0;
}();
// Use precomputed window origin & tensor descriptor
auto lds_bottom_tensor_thread_idx =
window_origin + window_adaptor_warp_coord.get_bottom_index();
// Use precomputed tensor descriptor
const auto lds_coord =
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
// Calculate SMEM address using base pointer
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
lds_coord.get_offset() / Traits::PackedSize +
lds_ys_offset / Traits::PackedSize;
const auto dram_ys_offset = [&]() {
if constexpr(static_move_ys)
{
const auto coord_ys_offset = make_tensor_coordinate(
this->get_bottom_tensor_view().get_tensor_descriptor(), idx_ys_offset);
return coord_ys_offset.get_offset();
}
else
return 0;
}();
// Asynchronously read from the bottom tensor into LDS (smem)
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
offset,
offset + dram_ys_offset,
bool_constant<oob_conditional_check>{});
// Move thread coordinate if not last access
@@ -569,11 +619,15 @@ struct tile_window_with_static_distribution
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
if constexpr(!static_move_ys)
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord,
bottom_tensor_thread_coord,
idx_diff_ps_ys);
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
if constexpr(!static_move_ys)
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
}
});
});