[CK_TILE] Add Flatmm MX FP8 (#3208)

* Use async for flatmm mxfp4

* Fix preshuffle

* Add flatmm mxfp8

* Thanks, Copilot

* Thanks Copilot again~

[ROCm/composable_kernel commit: 47e2ed838e]
Author: Yi DING
Date: 2025-11-20 10:35:15 +08:00
Committed by: GitHub
Parent: 158fec303c
Commit: 0d9f230577
17 changed files with 698 additions and 595 deletions

View File

@@ -21,7 +21,9 @@ if(has_supported_gpu)
add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp ${EXAMPLE_MX_FLATMM_FILES})
target_include_directories(tile_example_mx_flatmm PRIVATE mxgemm)
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template so the source compiles without explicitly declared function specializations
# ... because they are auto-generated
set(EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template)
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)

View File

@@ -136,7 +136,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run MXFP4_Flatmm kernel " //
std::cout << "Run " << ck_tile::gemm_prec_str<ADataType, BDataType>() << " Flatmm kernel " //
<< " M = " << M << " N = " << N << " K = " << K << " StrideA = " << stride_A
<< " StrideB = " << stride_B << " StrideC = " << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
@@ -172,42 +172,47 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
template <class FlatmmConfig, class IterSrc, class IterDst>
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
template <ck_tile::index_t N_Warp_Tile, typename dtype>
auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
{
int KPack = 16;
int NLane = FlatmmConfig::N_Warp_Tile;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
auto src_lengths = src.get_lengths();
const int K = src_lengths[0];
const int N = src_lengths[1];
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
int KPack = 16 * packed_size; // fp4:32 or fp8:16
int NLane = N_Warp_Tile;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
for(int k = 0; k < K; k += packed_size)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int k0 = k / (KLane * KPack);
int tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
shuffled(outputIndex) = src(k, n);
}
}
return shuffled;
}
template <class FlatmmConfig, bool KLast, typename Src>
auto preShuffleScale(Src& src)
template <class FlatmmConfig, bool KLast, typename dtype>
auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
{
using dtype = typename Src::Data::value_type;
auto src_lengths = src.get_lengths();
const auto MN = KLast ? src_lengths[0] : src_lengths[1];
const auto K = KLast ? src_lengths[1] : src_lengths[0];
@@ -261,7 +266,6 @@ auto preShuffleScale(Src& src)
#include "run_mx_flatmm.inc"
template <typename FlatmmConfig>
int run_mx_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -278,24 +282,31 @@ int run_mx_flatmm_example(int argc, char* argv[])
if(a_layout == "R" && b_layout == "C")
{
if(mx_prec == "fp4xfp4")
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
FlatmmConfig,
MXfp4_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only non-persistent kernels are supported currently!");
}
else if(mx_prec == "fp6xfp6")
else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
{
throw std::runtime_error("Only support fp4xfp4 now!");
throw std::runtime_error("fp6xfp6 is not supported.");
}
else if(mx_prec == "fp8xfp8")
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
{
throw std::runtime_error("Only support fp4xfp4 now!");
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::fp16_t,
MXfp8_FlatmmConfig16,
false>(argc, argv, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only non-persistent kernels are supported currently!");
}
else
{
@@ -306,7 +317,6 @@ int run_mx_flatmm_example(int argc, char* argv[])
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
return -1;
}
int main(int argc, char* argv[])
@@ -319,7 +329,7 @@ int main(int argc, char* argv[])
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return run_mx_flatmm_example<MXfp4_FlatmmConfig16>(argc, argv);
return run_mx_flatmm_example(argc, argv);
}
else if(warp_tile == 1)
{

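For reference, the N, K -> N0 K0 KLane NLane KPack mapping computed by preShuffleWeight above can be checked in isolation. A standalone sketch (not part of the commit; the tile sizes are illustrative), using the fp8 constants where PackedSize = 1 and therefore KPack = 16, with N_Warp_Tile = 16:

// Standalone illustration: recompute the preShuffleWeight index mapping
// for one element, outside ck_tile.
#include <cassert>
#include <cstdio>

int main()
{
    constexpr int NLane = 16;               // N_Warp_Tile
    constexpr int KLane = 64 / NLane;       // 4
    constexpr int KPack = 16;               // 16 * PackedSize(1) for fp8
    constexpr int N = 32, K = 128;          // illustrative sizes only
    constexpr int K0 = K / (KLane * KPack); // 2

    const int n = 17, k = 70;
    const int n0 = n / NLane, n1 = n % NLane;         // 1, 1
    const int k0 = k / (KLane * KPack);               // 1
    const int tempk = k % (KLane * KPack);            // 6
    const int k1 = tempk / KPack, k2 = tempk % KPack; // 0, 6
    const int outputIndex = n0 * KPack * NLane * KLane * K0 +
                            k0 * KPack * NLane * KLane + k1 * KPack * NLane +
                            n1 * KPack + k2;          // 2048 + 1024 + 16 + 6 = 3094
    assert(outputIndex < N * K);
    std::printf("(n=%d, k=%d) -> %d\n", n, k, outputIndex);
    return 0;
}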
View File

@@ -12,4 +12,87 @@
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "mxfp4_flatmm.hpp"
// GEMM config with 16x16 warp tile
struct MXfp4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 512;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
struct MXfp8_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise,
bool Splitk,
bool HasHotLoop,
ck_tile::TailNumber TailNum>
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);

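The derived quantities in these configs are easy to sanity-check by hand. A standalone sketch (not from the commit) of the arithmetic: with N_Tile = 256, N_Warp = 4, and N_Warp_Tile = 16 the fp8 config repeats each warp four times over N, while the fp4 config's N_Tile = 512 gives eight repeats.

// Standalone static checks mirroring the configs' derived values.
static_assert(256 % (16 * 4) == 0, "N_Tile must be covered by N_Warp * N_Warp_Tile");
static_assert(256 / 16 / 4 == 4, "N_Repeat for MXfp8_FlatmmConfig16");
static_assert(512 / 16 / 4 == 8, "N_Repeat for MXfp4_FlatmmConfig16 (N_Tile = 512)");
static_assert(128 % (1 * 16) == 0, "M_Tile covered by M_Warp * M_Warp_Tile");
static_assert(256 % 128 == 0, "K_Tile covered by K_Warp_Tile");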
View File

@@ -1,24 +1,29 @@
function(mx_flatmm_instance_generate FILE_LIST)
set(FLATMM_CONFIG MXfp4_FlatmmConfig16)
set(A_DATA_TYPE FP4)
set(B_DATA_TYPE FP4)
set(C_DATA_TYPE FP16)
set(A_LAYOUT ROW)
set(B_LAYOUT COL)
set(C_LAYOUT ROW)
set(FLATMM_CONFIG_FP4 "MXfp4_FlatmmConfig16")
set(FLATMM_CONFIG_FP8 "MXfp8_FlatmmConfig16")
# foreach(PERSISTENT false true)
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
foreach(PERSISTENT false)
foreach(SPLIT_K false true)
foreach(HAS_HOT_LOOP false true)
foreach(TAIL_NUMBER ODD EVEN)
set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
@ONLY)
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
foreach(DATA_TYPE FP4 FP8)
set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
set(A_DATA_TYPE ${DATA_TYPE})
set(B_DATA_TYPE ${DATA_TYPE})
foreach(SPLIT_K false true)
foreach(HAS_HOT_LOOP false true)
foreach(TAIL_NUMBER ODD EVEN)
set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${DATA_TYPE}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
@ONLY)
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
endforeach()
endforeach()
endforeach()
endforeach()

View File

@@ -18,6 +18,7 @@
// clang-format on
using FP4 = ck_tile::pk_fp4_t;
using FP8 = ck_tile::fp8_t;
using FP16 = ck_tile::fp16_t;
using BF16 = ck_tile::bf16_t;

View File

@@ -1,60 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
// GEMM config with 16x16 warp tile
struct MXfp4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 512;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise,
bool Splitk,
bool HasHotLoop,
ck_tile::TailNumber TailNum>
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);

View File

@@ -88,10 +88,7 @@ int run_mx_flatmm_with_layouts(int argc,
throw std::runtime_error("wrong! Unexpected init_method");
}
ck_tile::HostTensor<BDataType> b_shuffled_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffled_host.begin(), N, K);
const auto b_shuffled_host = preShuffleWeight<FlatmmConfig::N_Warp_Tile>(b_origin_host);
const auto scale_a_shuffled = preShuffleScale<FlatmmConfig, true>(scale_a);
const auto scale_b_shuffled = preShuffleScale<FlatmmConfig, false>(scale_b);

View File

@@ -41,6 +41,8 @@ using long_number = constant<v>;
template <bool b>
using bool_constant = constant<b>;
using true_type = bool_constant<true>;
using false_type = bool_constant<false>;
#define CK_TILE_LEFT_UNARY_OP(OP) \
template <auto x> \

View File

@@ -21,9 +21,10 @@ namespace ck_tile {
template <typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<TileWindow_>>>
typename offset_t,
typename = std::enable_if_t<std::is_class_v<TileWindow_>>>
CK_TILE_DEVICE auto load_tile_with_offset(const TileWindow_& tile_window,
index_t offset,
offset_t offset,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
@@ -67,11 +68,12 @@ template <typename DistributedTensor_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
typename offset_t,
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor_>> &&
std::is_class_v<TileWindow_>>>
CK_TILE_DEVICE auto load_tile_with_offset(DistributedTensor_& dst_tile,
const TileWindow_& tile_window,
index_t offset,
offset_t offset,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
@@ -147,29 +149,31 @@ template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool static_move_ys = false,
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>> &&
std::is_class_v<TileWindow_>>>
CK_TILE_DEVICE auto async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
CK_TILE_DEVICE void async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
index_t offset,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
number<i_access> = {},
bool_constant<oob_conditional_check> occ = {},
bool_constant<static_move_ys> smy = {})
{
return tile_window.async_load_with_offset(
offset, lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
tile_window.async_load_with_offset(offset, lds_tile, number<i_access>{}, occ, smy);
}
template <typename LdsTileWindow_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
bool oob_conditional_check = true,
bool static_move_ys = false>
CK_TILE_DEVICE void async_load_tile(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
number<i_access> = {},
bool_constant<oob_conditional_check> occ = {},
bool_constant<static_move_ys> smy = {})
{
return async_load_tile_with_offset(
lds_tile, tile_window, 0, number<i_access>{}, bool_constant<oob_conditional_check>{});
async_load_tile_with_offset(lds_tile, tile_window, 0, number<i_access>{}, occ, smy);
}
template <typename LdsTileWindow_,
@@ -177,19 +181,19 @@ template <typename LdsTileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_&& lds_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
CK_TILE_DEVICE void async_load_fence(index_t cnt = 0)
{
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

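The extra trailing bool_constant parameters added above keep old call sites valid while new callers opt in at compile time; the pipeline later in this commit does so through an async_load_tile_ helper that passes true_type{}/false_type{}. A minimal standalone sketch of this tag-dispatch pattern, using std::integral_constant in place of ck_tile's bool_constant (names are illustrative):

#include <cstdio>
#include <type_traits>

template <bool b>
using bool_constant = std::integral_constant<bool, b>;
using true_type     = bool_constant<true>;
using false_type    = bool_constant<false>;

// Trailing tag parameters with defaults: old two-argument calls still work,
// and new callers select behavior at compile time with true_type{}/false_type{}.
template <bool oob_conditional_check = true, bool static_move_ys = false>
void async_load_sketch(bool_constant<oob_conditional_check> = {},
                       bool_constant<static_move_ys>        = {})
{
    if constexpr(static_move_ys)
        std::puts("fold Y-dim steps into immediate offsets");
    else
        std::puts("move thread coordinates between accesses");
}

int main()
{
    async_load_sketch();                          // defaults, as before this commit
    async_load_sketch(true_type{}, false_type{}); // explicit, like async_load_tile_
    return 0;
}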
View File

@@ -166,8 +166,8 @@ struct tensor_view
{
return buf_.template async_get<X>(
smem,
coord.get_offset() / PackedSize,
linear_offset / PackedSize,
coord.get_offset() / PackedSize + linear_offset / PackedSize,
0, // linear_offset needs to be an immediate and is not currently supported
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
bool_constant<oob_conditional_check>{});
}

View File

@@ -156,8 +156,10 @@ struct tile_window_with_static_distribution
0, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_with_offset(index_t offset,
template <index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
typename offset_t = index_t>
CK_TILE_DEVICE auto load_with_offset(offset_t offset,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
@@ -291,14 +293,16 @@ struct tile_window_with_static_distribution
0, dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor,
template <typename DataType,
typename StaticTileDistribution,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor>>>>
CK_TILE_DEVICE auto load_with_offset(index_t offset,
DistributedTensor& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
typename offset_t>
CK_TILE_DEVICE void load_with_offset( //
offset_t offset,
static_distributed_tensor<DataType, StaticTileDistribution>& dst_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = typename Base::Traits;
using vector_t = typename Traits::vector_t;
@@ -306,6 +310,19 @@ struct tile_window_with_static_distribution
constexpr auto tile_dstr = typename Base::TileDstr{};
const index_t linear_off = [&]() {
if constexpr(std::is_integral_v<offset_t>)
return offset;
else if constexpr(is_constant_v<offset_t>)
return offset_t::value;
else
{
auto bottom_tensor_idx_off = to_multi_index(offset_t{});
auto bottom_tensor_coord_off = make_tensor_coordinate(
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
return bottom_tensor_coord_off.get_offset();
}
}();
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structured bindings (so they can be captured by the lambdas below) once compiled as C++20
@@ -321,7 +338,9 @@ struct tile_window_with_static_distribution
// read from bottom tensor
const vector_t vec_value =
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, offset, bool_constant<oob_conditional_check>{});
bottom_tensor_thread_coord,
linear_off,
bool_constant<oob_conditional_check>{});
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
@@ -514,11 +533,13 @@ struct tile_window_with_static_distribution
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool static_move_ys = false,
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
CK_TILE_DEVICE void async_load_with_offset(index_t offset,
LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
bool_constant<oob_conditional_check> = {},
bool_constant<static_move_ys> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
@@ -531,7 +552,7 @@ struct tile_window_with_static_distribution
const auto window_origin = lds_tile.get_window_origin();
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
auto lds_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
@@ -543,22 +564,51 @@ struct tile_window_with_static_distribution
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// Use precomputed window origin
constexpr auto idx_ys_offset = [&]() {
constexpr auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess);
constexpr auto adapter_ys_offset = make_tensor_adaptor_coordinate(
StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(),
container_concat(array<index_t, Base::NDimP>{0},
to_array<index_t, idx_off_ys.size()>(idx_off_ys)));
return adapter_ys_offset.get_bottom_index();
}();
const auto lds_ys_offset = [&]() {
if constexpr(static_move_ys)
{
const auto coord_ys_offset =
make_tensor_coordinate(tensor_descriptor, idx_ys_offset);
return coord_ys_offset.get_offset();
}
else
return 0;
}();
// Use precomputed window origin & tensor descriptor
auto lds_bottom_tensor_thread_idx =
window_origin + window_adaptor_warp_coord.get_bottom_index();
// Use precomputed tensor descriptor
const auto lds_coord =
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
// Calculate SMEM address using base pointer
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
lds_coord.get_offset() / Traits::PackedSize +
lds_ys_offset / Traits::PackedSize;
const auto dram_ys_offset = [&]() {
if constexpr(static_move_ys)
{
const auto coord_ys_offset = make_tensor_coordinate(
this->get_bottom_tensor_view().get_tensor_descriptor(), idx_ys_offset);
return coord_ys_offset.get_offset();
}
else
return 0;
}();
// Write into bottom tensor
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
bottom_tensor_thread_coord,
offset,
offset + dram_ys_offset,
bool_constant<oob_conditional_check>{});
// Move thread coordinate if not last access
@@ -569,11 +619,15 @@ struct tile_window_with_static_distribution
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
idx_diff_ys);
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
if constexpr(!static_move_ys)
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord,
bottom_tensor_thread_coord,
idx_diff_ps_ys);
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
if constexpr(!static_move_ys)
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
}
});
});

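The new offset_t parameter lets load_with_offset accept three kinds of offsets, resolved by the linear_off lambda above: a runtime integer, a compile-time constant, and a static multi-index that is folded through the tensor descriptor. A standalone sketch of that dispatch, with the descriptor reduced to a row-major stride (illustrative, not ck_tile code):

#include <cstdio>
#include <tuple>
#include <type_traits>

template <int v>
using number = std::integral_constant<int, v>;

// Runtime linear offset: passes through unchanged.
int linear_offset(int offset, int /*stride0*/) { return offset; }

// Compile-time linear offset: read from ::value.
template <int v>
int linear_offset(number<v>, int /*stride0*/) { return v; }

// Static multi-index: resolve through the layout (a row-major fold here;
// the real code builds a tensor coordinate through the descriptor instead).
template <int m, int k>
int linear_offset(std::tuple<number<m>, number<k>>, int stride0)
{
    return m * stride0 + k;
}

int main()
{
    std::printf("%d %d %d\n",
                linear_offset(5, 64),                                   // 5
                linear_offset(number<7>{}, 64),                         // 7
                linear_offset(std::tuple<number<2>, number<3>>{}, 64)); // 131
    return 0;
}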
View File

@@ -432,6 +432,12 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
}
else
{
a_m_k_scaled(m, k) =
ck_tile::type_convert<AccDataType>((a_m_k(m, k))) *
ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
}
}
}

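The added else branch covers element types that need no nibble unpacking: each value is scaled by the shared per-block scale, exactly as in the fp4 path but one element at a time. A standalone numeric sketch (ScaleBlockSize = 32 as in the MX formats; types reduced to float):

#include <cassert>
#include <vector>

int main()
{
    constexpr int K = 64, ScaleBlockSize = 32;
    std::vector<float> a(K, 2.0f);
    std::vector<float> scale = {0.5f, 4.0f}; // one scale per 32-element K-block
    std::vector<float> a_scaled(K);
    for(int k = 0; k < K; ++k)
        a_scaled[k] = a[k] * scale[k / ScaleBlockSize];
    assert(a_scaled[0] == 1.0f && a_scaled[63] == 8.0f);
    return 0;
}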
View File

@@ -19,6 +19,7 @@ template <> struct typeToStr<fp8_t> { static constexpr const char * name = "fp8"
template <> struct typeToStr<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct typeToStr<int8_t> { static constexpr const char * name = "int8"; };
template <> struct typeToStr<pk_int4_t> { static constexpr const char * name = "pk_int4"; };
template <> struct typeToStr<pk_fp4_t> { static constexpr const char * name = "pk_fp4"; };
template <memory_operation_enum MemOp> struct memOpToStr;
template <> struct memOpToStr<memory_operation_enum::set> { static constexpr const char * name = "set"; };

View File

@@ -143,16 +143,24 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
}
}();
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
const index_t kFlatN = kargs.N / kNWarpTile;
const auto& b_flat_tensor_view = [&]() {
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
"wrong! vector size for B tensor");
auto&& naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
auto&& desc = transform_tensor_descriptor(
naive_desc,
make_tuple(make_pass_through_transform(kFlatN),
make_merge_transform_v3_division_mod(
make_tuple(kFlatKBlocks, number<flatKPerBlock>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
}();
const auto& ds_tensor_view = generate_tuple(

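The reworked view describes the preshuffled B buffer as kFlatN blocks of kFlatKBlocks x flatKPerBlock elements and then merges the two K dimensions, rather than using the old kFlatN x kFlatK strided view. A standalone check of the element count, under illustrative sizes (the kernel derives the real values from kargs):

// Illustrative numbers only; not code from the commit.
constexpr int N = 256, K = 512, kKPerBlock = 256, kNWarpTile = 16;
constexpr int flatKPerBlock = kKPerBlock * kNWarpTile; // 4096 elements per (N-warp, K-block)
constexpr int kFlatKBlocks  = K / kKPerBlock;          // 2
constexpr int kFlatN        = N / kNWarpTile;          // 16
static_assert(kFlatN * kFlatKBlocks * flatKPerBlock == N * K,
              "the merged descriptor must cover the whole preshuffled B buffer");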
View File

@@ -44,7 +44,10 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
else if(TailNumber::Odd == tail_num)
return TailHandler<DispatchHotloop, TailNumber::Odd>(run_func, has_hot_loop);
else
{
assert(("Wrong TailNumber!", false));
return decltype(TailHandler<>(run_func, true, TailNumber::Even)){};
}
}
};

View File

@@ -43,7 +43,7 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
static constexpr int NXdlPack = 2; // it's fixed for fp4
static constexpr int KXdlPack = 2;
// static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
static constexpr index_t flatKPerWarp = 64 * ContinuousKPerThread;
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
};
template <typename Problem, typename PipelinePolicy = MXF4FlatmmPipelineAgBgCrPolicy>
@@ -122,9 +122,10 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr index_t MXdlPack = Problem::MXdlPack;
static constexpr index_t NXdlPack = Problem::NXdlPack;
static constexpr index_t KXdlPack = Problem::KXdlPack;
static constexpr index_t MXdlPack = Problem::MXdlPack;
static constexpr index_t NXdlPack = Problem::NXdlPack;
static constexpr index_t KXdlPack = Problem::KXdlPack;
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
@@ -138,25 +139,25 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static constexpr index_t mfma_per_wg = 1; // 950 only
static constexpr index_t dsread_per_wg =
WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize / Problem::VectorLoadSize;
static_assert((WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize) %
Problem::VectorLoadSize ==
0);
static constexpr index_t dsread_per_wg = WG::kM * WG::kK / AK1 / WaveSize;
static_assert((WG::kM * WG::kK) % (AK1 * WaveSize) == 0);
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
static constexpr index_t dswrite_num_perK = dsread_num_perK / NWarp;
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
static constexpr index_t Aload_num_perK = dswrite_num_perK;
static constexpr index_t Aload_rep = dswrite_rep;
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
static constexpr index_t ScaleBload_K1 = NXdlPack * KXdlPack; // fixed for fp4
static constexpr index_t Bload_num = Bload_num_perK * KIterPerWarp;
static constexpr index_t ScaleBload_num =
kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 / WaveSize;
static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
kNPerBlock * kKPerBlock / NWarp / ScaleGranularityK / NXdlPack / KXdlPack / WaveSize;
static constexpr index_t ScaleAload_num =
kMPerBlock * kKPerBlock / MWarp / ScaleGranularityK / MXdlPack / KXdlPack / WaveSize;
// static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
@@ -219,7 +220,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{
if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
{
@@ -234,7 +235,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
{
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
{
@@ -470,18 +471,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
}
template <typename ADramBlockWindowTmp,
typename AElementFunction,
typename BFlatBlockWindowTmp,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp>
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_window,
const ScaleBDramBlockWindowTmp& scale_b_window,
index_t num_loop,
void* __restrict__ p_smem_ping,
void* __restrict__ p_smem_pong) const
{
#ifndef __gfx950__
static_assert(false, "Only gfx950 is currently supported by the MXFP4 flatmm pipeline.");
@@ -495,9 +494,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
const index_t iMWarp = get_warp_id() / NWarp;
// const index_t iNWarp = get_warp_id() % NWarp;
// constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
static_assert(NWarp == 4);
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
@@ -506,6 +504,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
auto a_dram_window =
make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor<Problem>(
a_copy_dram_window_tmp.get_bottom_tensor_view()),
a_copy_dram_window_tmp.get_window_lengths(),
a_copy_dram_window_tmp.get_window_origin(),
PipelinePolicy::template MakeMXFP4_ADramTileDistribution<Problem>());
__builtin_amdgcn_sched_barrier(0);
// A tile in LDS
@@ -520,93 +525,51 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
auto a_lds_block_pong =
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
auto a_copy_lds_window_ping =
make_tile_window(a_lds_block_ping,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto a_copy_lds_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto a_store_lds_window_ping = make_tile_window(
a_lds_block_ping, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
auto a_store_lds_window_pong = make_tile_window(
a_lds_block_pong, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
// ping-pong window for A LDS
auto a_warp_window_ping_tmp =
auto a_warp_window_ping =
make_tile_window(a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
{0, 0},
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
auto a_warp_window_pong_tmp =
auto a_warp_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
{0, 0},
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows_ping;
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
auto packed_m_idx = mIter / number<MXdlPack>{};
auto packed_m_rank = mIter % number<MXdlPack>{};
move_tile_window(
a_warp_windows_ping(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
kIter * KPerBlockPerIter});
move_tile_window(
a_warp_windows_pong(mIter)(kIter),
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
kIter * KPerBlockPerIter});
});
});
// Block GEMM
auto block_flatmm = BlockFlatmm();
// Acc register tile
auto c_block_tile = block_flatmm.MakeCBlockTile();
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
using MXFP4_B_Buffer = decltype(load_tile(b_flat_dram_window));
// use v4i32 as the data type between basic blocks to avoid unpack/repack operations.
using V4UInt_B_Buffer = thread_buffer<uint32_t, 4>;
union UnionBuf
{
V4UInt_B_Buffer u = 0;
MXFP4_B_Buffer mxfp4;
} ub;
// pingpong buffer for B
auto b_flat_dram_windows = generate_tuple(
[&](auto nIter) {
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
constexpr auto packed_n_rank = nIter % number<NXdlPack>{};
auto window_i = make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>());
move_tile_window(
window_i,
{number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank>{},
number<0>{}});
return window_i;
},
number<NIterPerWarp>{});
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
statically_indexed_array<decltype(load_tile(b_flat_dram_windows(I0))), KIterPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_pong;
b_warp_tensor_ping, b_warp_tensor_pong;
// pingpong buffer for Scale A and Scale B
auto scale_a_dram_window = make_tile_window(
@@ -649,29 +612,24 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
NIterPerWarp / NXdlPack>
scale_b_tile_tensor_pong;
auto async_load_tile_ = [](auto lds, auto dram) {
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
};
// HEAD
// Prefetch A0
auto a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
async_load_tile_(a_store_lds_window_ping, a_dram_window);
move_tile_window(a_dram_window, {0, kKPerBlock});
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
auto packed_n_idx = nIter / number<NXdlPack>{};
auto packed_n_rank = nIter % number<NXdlPack>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = ub.u;
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
});
// move B window to next flat K
move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, KIterPerWarp * KFlatPerBlockPerIter});
// prefetch Scale A
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -700,71 +658,40 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
// move Scale B window to next K
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// A_Lds_TileDist may differ with ADramTileDistribution
auto a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
__builtin_amdgcn_sched_barrier(0);
// Prefetch A1
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
if constexpr(HasHotLoop || TailNum == TailNumber::Even)
{
async_load_tile_(a_store_lds_window_pong, a_dram_window);
move_tile_window(a_dram_window, {0, kKPerBlock});
}
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
clear_tile(c_block_tile);
block_sync_lds();
using MXFP4_A_Buffer_ping =
decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{})));
// use v4i32 as the data type between basic blocks to avoid unpack/repack operations.
using V4UInt_A_Buffer = thread_buffer<uint32_t, 4>;
union UnionBuf_A_ping
{
V4UInt_A_Buffer u = 0;
MXFP4_A_Buffer_ping mxfp4;
} ua_ping;
using MXFP4_A_Buffer_pong =
decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{})));
union UnionBuf_A_pong
{
V4UInt_A_Buffer u = 0;
MXFP4_A_Buffer_pong mxfp4;
} ua_pong;
statically_indexed_array<decltype(load_tile(a_warp_window_pong)), m_preload> a_warp_tensor;
// preload A00,A10... from lds
statically_indexed_array<V4UInt_A_Buffer, m_preload> a_warp_tensor;
s_waitcnt_barrier</*vmcnt*/ dswrite_num_perK>();
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_ping.u;
a_warp_tensor(loadIter) = load_tile_with_offset(
a_warp_window_ping, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
});
__builtin_amdgcn_sched_barrier(0);
// MAIN LOOP
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)
{
auto main_body_implx2 = [&]() mutable {
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
auto packed_n_idx = nIter / number<NXdlPack>{};
auto packed_n_rank = nIter % number<NXdlPack>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = ub.u;
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
if constexpr(kIter == KIterPerWarp - 1)
move_tile_window(b_flat_dram_windows(nIter),
{0, BlockGemmShape::flatKPerBlock});
});
});
@@ -791,15 +718,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
});
// Prefill A(2i+1)
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
// Prefetch A(2i+2)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -807,30 +725,26 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<m_iter, n_iter>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_ping ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
ua_compute.mxfp4,
ub_compute.mxfp4,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0],
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -838,68 +752,60 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
nIter_pack * NXdlPack + inxdl>{},
merge_sequences(sequence<m_iter, n_iter>{},
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
// preload next A from lds
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
(kIter_pack * KXdlPack + ikxdl) * 2 +
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
m_preload;
constexpr auto addr =
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
a_warp_window_ping,
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
}
});
});
});
});
});
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
s_waitcnt< // vmcnt
Bload_num + ScaleAload_num + ScaleBload_num>();
block_sync_lds();
// Prefetch A(2i+2)
async_load_tile_(a_store_lds_window_ping, a_dram_window);
move_tile_window(a_dram_window, {0, kKPerBlock});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// preload A(2i+1)
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
a_warp_tensor(loadIter) = load_tile_with_offset(
a_warp_window_pong, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
});
HotLoopScheduler();
// Next K
////////////////////////////// Next K //////////////////////////////
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
auto packed_n_idx = nIter / number<NXdlPack>{};
auto packed_n_rank = nIter % number<NXdlPack>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_ping(nIter)(kIter) = ub.u;
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
if constexpr(kIter == KIterPerWarp - 1)
move_tile_window(b_flat_dram_windows(nIter),
{0, BlockGemmShape::flatKPerBlock});
});
});
@@ -926,15 +832,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
});
// Prefill A(2i+2)
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
// Prefetch A(2i+3)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i+1
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -953,20 +850,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_pong ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
ua_compute.mxfp4,
ub_compute.mxfp4,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -988,39 +878,47 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_pong.mxfp4 = load_tile(
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
}
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
a_warp_window_pong,
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
}
});
});
});
});
});
// barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished
s_waitcnt< // vmcnt
Bload_num + ScaleAload_num + ScaleBload_num>();
block_sync_lds();
// Prefetch A(2i+3)
async_load_tile_(a_store_lds_window_pong, a_dram_window);
move_tile_window(a_dram_window, {0, kKPerBlock});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
// preload A(2i+2)
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_ping.u; // reload a_warp_tensor with ping buffer
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
a_warp_tensor(loadIter) = load_tile_with_offset(
a_warp_window_ping, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
});
HotLoopScheduler();
};
iCounter--;
if constexpr(HasHotLoop)
{
index_t iCounter = (num_loop - 1) / 2;
do
{
main_body_implx2();
iCounter--;
} while(iCounter > 0);
}
// TAIL
@@ -1029,18 +927,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
auto packed_n_idx = nIter / number<NXdlPack>{};
auto packed_n_rank = nIter % number<NXdlPack>{};
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
kIter * KFlatPerBlockPerIter});
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
b_warp_tensor_pong(nIter)(kIter) = ub.u;
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
b_flat_dram_windows(nIter),
make_tuple(number<0>{}, number<kIter * KFlatPerBlockPerIter>{}));
});
});
@@ -1055,7 +944,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
});
});
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
@@ -1067,10 +955,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
});
// Prefill A(loopK)
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
// GEMM loopK-1
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
@@ -1089,20 +973,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_ping ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
ua_compute.mxfp4,
ub_compute.mxfp4,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -1124,30 +1001,28 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
a_warp_window_ping,
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
}
});
});
});
});
});
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
s_waitcnt< // vmcnt
Bload_num + ScaleAload_num + ScaleBload_num>();
block_sync_lds();
// preload A(2i+1)
static_for<0, m_preload, 1>{}([&](auto loadIter) {
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
constexpr auto mIter = loadIter % MXdlPack;
constexpr auto kIter = loadIter / MXdlPack;
a_warp_tensor(loadIter) = load_tile_with_offset(
a_warp_window_pong, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
});
Last2ndHotLoopScheduler();
@@ -1170,19 +1045,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_pong ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
ua_compute.mxfp4,
ub_compute.mxfp4,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
@@ -1204,18 +1073,11 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_pong.mxfp4 = load_tile(
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
}
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
a_warp_window_pong,
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
}
});
});
@@ -1244,20 +1106,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
UnionBuf_A_ping ua_compute;
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
UnionBuf ub_compute;
ub_compute.u =
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl);
// warp GEMM
WG{}.template
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
c_warp_tensor,
ua_compute.mxfp4,
ub_compute.mxfp4,
a_warp_tensor(number<AwarpIter>{}),
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
kIter_pack * number<KXdlPack>{} + ikxdl),
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
.get_thread_buffer()[0], // scale A
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
@@ -1279,18 +1134,11 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
(nIter_pack == NIterPerWarp / NXdlPack - 1))
{
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
ua_ping.mxfp4 = load_tile(
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
}
// barrier
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
{
block_sync_lds();
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
constexpr auto AkIter = addr / 2 % 2;
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
a_warp_window_ping,
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
}
});
});
@@ -1299,32 +1147,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
});
LastHotLoopScheduler();
}
else
{
static_assert(false, "Wrong TailNum");
}
return c_block_tile;
}
template <typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename ScaleADramBlockWindowTmp,
typename ScaleBDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
const ScaleADramBlockWindowTmp& scale_a_flat_window_tmp,
const ScaleBDramBlockWindowTmp& scale_b_flat_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType & a) { return a; },
b_flat_dram_block_window_tmp,
scale_a_flat_window_tmp,
scale_b_flat_window_tmp,
num_loop,
p_smem_ping,
p_smem_pong);
}
};
} // namespace ck_tile
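The overload added above forwards to the primary entry point with an identity A-element functor, so callers that do not pre-transform A can simply omit that argument. A minimal call-site sketch follows; the Pipeline alias and all window/buffer names here are hypothetical illustrations, not taken from this diff:

// Illustration only; Pipeline and the *_window names are assumptions.
extern __shared__ char smem[];
char* p_ping = smem;
char* p_pong = smem + Pipeline::GetSmemSize();
auto c_tile  = Pipeline{}(a_dram_window,   // A tile window in DRAM
                          b_flat_window,   // pre-shuffled flat B window
                          scale_a_window,  // MX block-scale windows
                          scale_b_window,
                          num_loop,        // main-loop trip count
                          p_ping,
                          p_pong);         // double-buffered LDS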

View File

@@ -13,22 +13,139 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr index_t KBPerLoad = 32;
static constexpr index_t kDramLoadPackBytes = 128;
static constexpr int MXdlPack = 2;
static constexpr int NXdlPack = 2;
static constexpr int KXdlPack = 2;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
{
using namespace ck_tile;
static inline constexpr auto wg_attr_num_access =
std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
? WGAttrNumAccessEnum::Single
: WGAttrNumAccessEnum::Double;
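    // The selector above issues a single MFMA operand fetch for packed fp4
    // (32 fp4 = 16 bytes, one read) but two fetches for fp8 (32 fp8 = 32 bytes).
    // Hedged sketch of the intent, with hypothetical problem aliases:
    //   static_assert(wg_attr_num_access<MxFp4Problem> == WGAttrNumAccessEnum::Single);
    //   static_assert(wg_attr_num_access<MxFp8Problem> == WGAttrNumAccessEnum::Double);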
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(
sizeof(ADataType) * numeric_traits<BDataType>::PackedSize ==
sizeof(BDataType) * numeric_traits<ADataType>::PackedSize,
"sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!");
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher< //
ADataType,
BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access<Problem>>;
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
ADataType,
BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
}
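    // Hedged worked example of the assert above: with A = fp8_t (sizeof 1,
    // PackedSize 1) and B = pk_fp4_t (sizeof 1, PackedSize 2) the check reads
    // 1 * 2 != 1 * 1, so mixed fp8/fp4 problems are rejected at compile time,
    // while fp4/fp4 (2 == 2) and fp8/fp8 (1 == 1) both pass.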
template <typename Problem, typename TensorView>
CK_TILE_DEVICE static constexpr auto
MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
const auto& naive_desc = naive_view.get_tensor_descriptor();
constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
        static_assert(ndims == 2, "only 2D tensors are supported");
const auto rows = naive_desc.get_length(number<0>{});
const auto cols = naive_desc.get_length(number<1>{});
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
const index_t K0 = cols / (K1 * K2);
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
        constexpr index_t M1 = 4; // so that LDS accesses can use an immediate offset
const index_t M0 = rows / M1;
const auto row_lens = make_tuple(M0, number<M1>{});
const auto desc_0 =
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(M0),
make_xor_transform(make_tuple(number<M1>{}, number<K1>{})),
make_pass_through_transform(K0),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
const auto desc = transform_tensor_descriptor( //
desc_1,
make_tuple(make_merge_transform_v3_division_mod(row_lens),
make_merge_transform_v3_division_mod(col_lens)),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));
return tensor_view<typename TensorView::buffer_view,
remove_cvref_t<decltype(desc)>,
TensorView::DstInMemOp>{naive_view.buf_, desc};
}
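    // Hedged host-side sketch (illustration only, not part of the kernel) of
    // the element offset the descriptor above encodes, assuming
    // make_xor_transform lowers the pair (m1, k1) to (m1, m1 ^ k1); the
    // defaults K1 = 8 and M1 = 4 follow the constants above.
    inline int async_load_offset(int m, int k, int K0, int K1 = 8, int K2 = 32)
    {
        constexpr int M1 = 4;
        const int m0 = m / M1, m1 = m % M1;
        const int k0 = k / (K1 * K2);
        const int k1 = (k / K2) % K1;
        const int k2 = k % K2;
        const int k1_sw = m1 ^ k1; // xor swizzle so the async copy lands pre-swizzled for LDS
        return (((m0 * M1 + m1) * K0 + k0) * K1 + k1_sw) * K2 + k2;
    }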
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
constexpr index_t M2 = get_warp_size() / K1; // 8
constexpr index_t M1 = BlockSize / get_warp_size(); // 4
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<1>,
                tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // M: M0,4,8; K: 1,8,32 (fp4) or 2,8,16 (fp8)
                tuple<sequence<1>, sequence<1, 2>>, // warps -> M1; lanes -> M2, K1
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>, // M0,K0,K2
sequence<0, 0, 2>>{});
}
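    // Hedged worked numbers for the annotation above (KPerBlock = 256,
    // BlockSize = 256, warp size 64 and a 16-byte LDS pack assumed):
    static_assert(128 * 2 / 32 == 8 && 256 / (8 * 32) == 1, "fp4: K1 = 8, K0 = 1");
    static_assert(128 * 1 / 16 == 8 && 256 / (8 * 16) == 2, "fp8: K1 = 8, K0 = 2");
    // M side in both cases: M2 = 64 / 8 = 8 lanes, M1 = 256 / 64 = 4 warps.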
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
@@ -36,65 +153,70 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
        constexpr index_t M3 = 4; // so that LDS accesses can use an immediate offset
constexpr index_t M2 = get_warp_size() / K1 / M3; // 2
constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");
        constexpr index_t Pad = 4 * K2; // fp4: 4 * 32, fp8: 4 * 16
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
make_tuple(number<M0>{},
number<M1>{},
number<K0>{},
number<M2>{},
number<M3>{},
number<K1>{},
number<K2>{}),
make_tuple(number<M1*(K0 * (M2 * M3 * K1 * K2) + (K0 - 1) * Pad)>{},
number<K0*(M2 * M3 * K1 * K2) + (K0 - 1) * Pad>{},
number<M2 * M3 * K1 * K2 + Pad>{},
number<M3 * K1 * K2>{},
number<K1 * K2>{},
number<K2>{},
number<1>{}),
number<K2>{},
number<1>{});
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(M1),
make_pass_through_transform(K0),
make_pass_through_transform(M2),
make_xor_transform(make_tuple(number<M3>{}, number<K1>{})),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4, 5>{},
sequence<6>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
a_lds_block_desc_1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<M0>{}, number<M1>{}, number<M2>{}, number<M3>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// return a_lds_block_desc_permuted;
return a_lds_block_desc;
}
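    // Hedged sketch of the padded offset the strides above encode (pre-xor):
    // one K0 "slab" holds M2*M3*K1*K2 elements, and consecutive slabs are
    // separated by Pad = 4*K2 elements, presumably to avoid LDS bank conflicts.
    inline int lds_offset(int m0, int m1, int k0, int m2, int m3, int k1, int k2,
                          int K0, int M1, int M2, int M3, int K1, int K2)
    {
        const int Pad  = 4 * K2;
        const int slab = M2 * M3 * K1 * K2;
        const int row  = K0 * slab + (K0 - 1) * Pad; // one M1 row: K0 slabs + pads
        return m0 * (M1 * row) + m1 * row + k0 * (slab + Pad) +
               m2 * (M3 * K1 * K2) + m3 * (K1 * K2) + k1 * K2 + k2;
    }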
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
{
@@ -105,20 +227,31 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
constexpr int M_Lane = TileShape::WarpTile::at(I0);
constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I0); // 4
constexpr int K_Lane = 64 / M_Lane; // 4
constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 32
constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
constexpr int K1 = K_Thread / num_access_v; // 16
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 2>>,
sequence<2>,
sequence<1>>{});
std::conditional_t<
num_access_v == 1,
tile_distribution_encoding<
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 2>>,
sequence<2>,
sequence<1>>,
tile_distribution_encoding< //
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<num_access_v, K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 2>,
sequence<0, 2>>>{});
}
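    // Hedged numbers, assuming a 16x16x128 warp tile: K_Lane = 64 / 16 = 4 and
    // K_Thread = 128 / 4 = 32 elements per lane, so
    //   fp4: num_access = 1 -> K1 = 32 packed-fp4 elements (one 16-byte ds_read)
    //   fp8: num_access = 2 -> K1 = 16 fp8 elements (two 16-byte ds_reads)
    static_assert(64 / 16 == 4 && 128 / 4 == 32, "illustrative warp-tile numbers");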
template <typename Problem>
@@ -132,25 +265,36 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t K1 = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t K0 = KWavePerBlk;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
constexpr index_t kKPerThread = 32;
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
constexpr index_t K2 = kKPerThread / num_access_v;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>,
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
tuple<sequence<0, 0, 0>, sequence<1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<2>>{});
std::conditional_t< //
num_access_v == 1,
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<K0, K1, K2>>, // 1 64 32
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 0>, sequence<1>>,
sequence<2>,
sequence<2>>,
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<num_access_v, K0, K1, K2>>, // 2 1 64 16
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 1>, sequence<2>>,
sequence<2, 2>,
sequence<0, 3>>>{});
}
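    // Hedged per-lane view of the two encodings above, with kKPerThread = 32:
    //   fp4 (Single): K2 = 32 fp4 elements = 16 bytes, one access   -> 1, 64, 32
    //   fp8 (Double): K2 = 16 fp8 elements = 16 bytes, two accesses -> 2, 1, 64, 16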
template <typename Problem>
@@ -270,6 +414,21 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
sequence<2>,
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
return sizeof(ADataType) *
MakeMXFP4_ALdsBlockDescriptor<Problem>().get_element_space_size() / APackedSize;
}
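    // Hedged example: the descriptor's element space is counted in unpacked
    // elements, so pk_fp4_t (sizeof 1, PackedSize 2) needs element_space / 2
    // bytes of LDS, while fp8 (sizeof 1, PackedSize 1) needs element_space
    // bytes, padding included in both cases.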
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return GetSmemSizeA<Problem>();
}
};
} // namespace ck_tile