mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Add Flatmm MX FP8 (#3208)
* Use async for flatmm mxfp4
* Fix preshuffle
* Add flatmm mxfp8
* Thanks, Copilot
* Thanks Copilot again~
[ROCm/composable_kernel commit: 47e2ed838e]
This commit is contained in:
@@ -21,7 +21,9 @@ if(has_supported_gpu)
|
||||
add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp ${EXAMPLE_MX_FLATMM_FILES})
|
||||
target_include_directories(tile_example_mx_flatmm PRIVATE mxgemm)
|
||||
|
||||
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
# ... because they are auto-generated
|
||||
set(EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template)
|
||||
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
|
||||
|
||||
if(CK_USE_OCP_FP8)
|
||||
|
||||
@@ -136,7 +136,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run MXFP4_Flatmm kernel " //
|
||||
std::cout << "Run " << ck_tile::gemm_prec_str<ADataType, BDataType>() << " Flatmm kernel " //
|
||||
<< " M = " << M << " N = " << N << " K = " << K << " StrideA = " << stride_A
|
||||
<< " StrideB = " << stride_B << " StrideC = " << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
@@ -172,42 +172,47 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, class IterSrc, class IterDst>
|
||||
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
|
||||
template <ck_tile::index_t N_Warp_Tile, typename dtype>
|
||||
auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = FlatmmConfig::N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
auto src_lengths = src.get_lengths();
|
||||
const int K = src_lengths[0];
|
||||
const int N = src_lengths[1];
|
||||
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
|
||||
int KPack = 16 * packed_size; // fp4:32 or fp8:16
|
||||
int NLane = N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K0 = K / (KLane * KPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
|
||||
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
for(int k = 0; k < K; k += packed_size)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
int k0 = k / (KLane * KPack);
|
||||
int tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K_pk + k];
|
||||
shuffled(outputIndex) = src(k, n);
|
||||
}
|
||||
}
|
||||
return shuffled;
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, bool KLast, typename Src>
|
||||
auto preShuffleScale(Src& src)
|
||||
template <class FlatmmConfig, bool KLast, typename dtype>
|
||||
auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
using dtype = typename Src::Data::value_type;
|
||||
auto src_lengths = src.get_lengths();
|
||||
const auto MN = KLast ? src_lengths[0] : src_lengths[1];
|
||||
const auto K = KLast ? src_lengths[1] : src_lengths[0];
|
||||
@@ -261,7 +266,6 @@ auto preShuffleScale(Src& src)
|
||||
|
||||
#include "run_mx_flatmm.inc"
|
||||
|
||||
template <typename FlatmmConfig>
|
||||
int run_mx_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -278,24 +282,31 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(mx_prec == "fp4xfp4")
|
||||
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
FlatmmConfig,
|
||||
MXfp4_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only non-persistent kernels are supported currently!");
|
||||
}
|
||||
else if(mx_prec == "fp6xfp6")
|
||||
else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
|
||||
{
|
||||
throw std::runtime_error("Only support fp4xfp4 now!");
|
||||
throw std::runtime_error("fp6xfp6 is not supported.");
|
||||
}
|
||||
else if(mx_prec == "fp8xfp8")
|
||||
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
|
||||
{
|
||||
throw std::runtime_error("Only support fp4xfp4 now!");
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp16_t,
|
||||
MXfp8_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -306,7 +317,6 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
@@ -319,7 +329,7 @@ int main(int argc, char* argv[])
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return run_mx_flatmm_example<MXfp4_FlatmmConfig16>(argc, argv);
|
||||
return run_mx_flatmm_example(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
|
||||
@@ -12,4 +12,87 @@
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include "mxfp4_flatmm.hpp"
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 512;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MXfp8_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise,
|
||||
bool Splitk,
|
||||
bool HasHotLoop,
|
||||
ck_tile::TailNumber TailNum>
|
||||
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s);
|
||||
|
||||
@@ -1,24 +1,29 @@
|
||||
function(mx_flatmm_instance_generate FILE_LIST)
|
||||
set(FLATMM_CONFIG MXfp4_FlatmmConfig16)
|
||||
set(A_DATA_TYPE FP4)
|
||||
set(B_DATA_TYPE FP4)
|
||||
set(C_DATA_TYPE FP16)
|
||||
set(A_LAYOUT ROW)
|
||||
set(B_LAYOUT COL)
|
||||
set(C_LAYOUT ROW)
|
||||
set(FLATMM_CONFIG_FP4 "MXfp4_FlatmmConfig16")
|
||||
set(FLATMM_CONFIG_FP8 "MXfp8_FlatmmConfig16")
|
||||
|
||||
# foreach(PERSISTENT false true)
|
||||
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
|
||||
foreach(PERSISTENT false)
|
||||
foreach(SPLIT_K false true)
|
||||
foreach(HAS_HOT_LOOP false true)
|
||||
foreach(TAIL_NUMBER ODD EVEN)
|
||||
set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
|
||||
@ONLY)
|
||||
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
|
||||
foreach(DATA_TYPE FP4 FP8)
|
||||
set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
|
||||
set(A_DATA_TYPE ${DATA_TYPE})
|
||||
set(B_DATA_TYPE ${DATA_TYPE})
|
||||
foreach(SPLIT_K false true)
|
||||
foreach(HAS_HOT_LOOP false true)
|
||||
foreach(TAIL_NUMBER ODD EVEN)
|
||||
set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${DATA_TYPE}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
|
||||
string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
|
||||
@ONLY)
|
||||
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
// clang-format on
|
||||
|
||||
using FP4 = ck_tile::pk_fp4_t;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using FP16 = ck_tile::fp16_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
|
||||
|
||||
@@ -1,60 +0,0 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 512;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise,
|
||||
bool Splitk,
|
||||
bool HasHotLoop,
|
||||
ck_tile::TailNumber TailNum>
|
||||
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s);
|
||||
@@ -88,10 +88,7 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
throw std::runtime_error("wrong! Unexpected init_method");
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffled_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffled_host.begin(), N, K);
|
||||
|
||||
const auto b_shuffled_host = preShuffleWeight<FlatmmConfig::N_Warp_Tile>(b_origin_host);
|
||||
const auto scale_a_shuffled = preShuffleScale<FlatmmConfig, true>(scale_a);
|
||||
const auto scale_b_shuffled = preShuffleScale<FlatmmConfig, false>(scale_b);
|
||||
|
||||
|
||||
@@ -41,6 +41,8 @@ using long_number = constant<v>;
|
||||
|
||||
template <bool b>
|
||||
using bool_constant = constant<b>;
|
||||
using true_type = bool_constant<true>;
|
||||
using false_type = bool_constant<false>;
|
||||
|
||||
#define CK_TILE_LEFT_UNARY_OP(OP) \
|
||||
template <auto x> \
|
||||
|
||||
@@ -21,9 +21,10 @@ namespace ck_tile {
|
||||
template <typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<std::is_class_v<TileWindow_>>>
|
||||
typename offset_t,
|
||||
typename = std::enable_if_t<std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE auto load_tile_with_offset(const TileWindow_& tile_window,
|
||||
index_t offset,
|
||||
offset_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
@@ -67,11 +68,12 @@ template <typename DistributedTensor_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename offset_t,
|
||||
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor_>> &&
|
||||
std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE auto load_tile_with_offset(DistributedTensor_& dst_tile,
|
||||
const TileWindow_& tile_window,
|
||||
index_t offset,
|
||||
offset_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
{
|
||||
@@ -147,29 +149,31 @@ template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool static_move_ys = false,
|
||||
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>> &&
|
||||
std::is_class_v<TileWindow_>>>
|
||||
CK_TILE_DEVICE auto async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
|
||||
CK_TILE_DEVICE void async_load_tile_with_offset(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
index_t offset,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> occ = {},
|
||||
bool_constant<static_move_ys> smy = {})
|
||||
{
|
||||
return tile_window.async_load_with_offset(
|
||||
offset, lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
tile_window.async_load_with_offset(offset, lds_tile, number<i_access>{}, occ, smy);
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
typename TileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
|
||||
bool oob_conditional_check = true,
|
||||
bool static_move_ys = false>
|
||||
CK_TILE_DEVICE void async_load_tile(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {})
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> occ = {},
|
||||
bool_constant<static_move_ys> smy = {})
|
||||
{
|
||||
return async_load_tile_with_offset(
|
||||
lds_tile, tile_window, 0, number<i_access>{}, bool_constant<oob_conditional_check>{});
|
||||
async_load_tile_with_offset(lds_tile, tile_window, 0, number<i_access>{}, occ, smy);
|
||||
}
|
||||
|
||||
template <typename LdsTileWindow_,
|
||||
@@ -177,19 +181,19 @@ template <typename LdsTileWindow_,
|
||||
index_t i_access = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool pre_nop = false>
|
||||
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
CK_TILE_DEVICE void async_load_tile_raw(LdsTileWindow_&& lds_tile,
|
||||
const TileWindow_& tile_window,
|
||||
number<i_access> = {},
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<pre_nop> = {})
|
||||
{
|
||||
return tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
tile_window.async_load_raw(lds_tile,
|
||||
number<i_access>{},
|
||||
bool_constant<oob_conditional_check>{},
|
||||
bool_constant<pre_nop>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
|
||||
CK_TILE_DEVICE void async_load_fence(index_t cnt = 0)
|
||||
{
|
||||
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
|
||||
}
|
||||
|
||||
@@ -166,8 +166,8 @@ struct tensor_view
|
||||
{
|
||||
return buf_.template async_get<X>(
|
||||
smem,
|
||||
coord.get_offset() / PackedSize,
|
||||
linear_offset / PackedSize,
|
||||
coord.get_offset() / PackedSize + linear_offset / PackedSize,
|
||||
0, // linear_offset need to be imm and is not supported currently
|
||||
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
|
||||
bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
@@ -156,8 +156,10 @@ struct tile_window_with_static_distribution
|
||||
0, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
|
||||
CK_TILE_DEVICE auto load_with_offset(index_t offset,
|
||||
template <index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename offset_t = index_t>
|
||||
CK_TILE_DEVICE auto load_with_offset(offset_t offset,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
@@ -291,14 +293,16 @@ struct tile_window_with_static_distribution
|
||||
0, dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor,
|
||||
template <typename DataType,
|
||||
typename StaticTileDistribution,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
typename = std::enable_if_t<std::is_class_v<std::remove_cv_t<DistributedTensor>>>>
|
||||
CK_TILE_DEVICE auto load_with_offset(index_t offset,
|
||||
DistributedTensor& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
typename offset_t>
|
||||
CK_TILE_DEVICE void load_with_offset( //
|
||||
offset_t offset,
|
||||
static_distributed_tensor<DataType, StaticTileDistribution>& dst_tensor,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
{
|
||||
using Traits = typename Base::Traits;
|
||||
using vector_t = typename Traits::vector_t;
|
||||
@@ -306,6 +310,19 @@ struct tile_window_with_static_distribution
|
||||
|
||||
constexpr auto tile_dstr = typename Base::TileDstr{};
|
||||
|
||||
const index_t linear_off = [&]() {
|
||||
if constexpr(std::is_integral_v<offset_t>)
|
||||
return offset;
|
||||
else if constexpr(is_constant_v<offset_t>)
|
||||
return offset_t::value;
|
||||
else
|
||||
{
|
||||
auto bottom_tensor_idx_off = to_multi_index(offset_t{});
|
||||
auto bottom_tensor_coord_off = make_tensor_coordinate(
|
||||
this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
|
||||
return bottom_tensor_coord_off.get_offset();
|
||||
}
|
||||
}();
|
||||
// loop over thread tensor space [y0, y1, ...]
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
/// TODO: use structure binding (to be captured later) if compiled in C++20
|
||||
@@ -321,7 +338,9 @@ struct tile_window_with_static_distribution
|
||||
// read from bottom tensor
|
||||
const vector_t vec_value =
|
||||
this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, offset, bool_constant<oob_conditional_check>{});
|
||||
bottom_tensor_thread_coord,
|
||||
linear_off,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_tuple(
|
||||
@@ -514,11 +533,13 @@ struct tile_window_with_static_distribution
|
||||
template <typename LdsTileWindow_,
|
||||
index_t i_access_unsupport_ = -1,
|
||||
bool oob_conditional_check = true,
|
||||
bool static_move_ys = false,
|
||||
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
|
||||
CK_TILE_DEVICE void async_load_with_offset(index_t offset,
|
||||
LdsTileWindow_&& lds_tile,
|
||||
number<i_access_unsupport_> = {},
|
||||
bool_constant<oob_conditional_check> = {}) const
|
||||
bool_constant<oob_conditional_check> = {},
|
||||
bool_constant<static_move_ys> = {}) const
|
||||
{
|
||||
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
|
||||
using LdsDataType = typename LdsTileWindow::DataType;
|
||||
@@ -531,7 +552,7 @@ struct tile_window_with_static_distribution
|
||||
const auto window_origin = lds_tile.get_window_origin();
|
||||
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
|
||||
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
|
||||
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
|
||||
auto lds_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
|
||||
|
||||
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
|
||||
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
|
||||
@@ -543,22 +564,51 @@ struct tile_window_with_static_distribution
|
||||
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
|
||||
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
|
||||
|
||||
// Use precomputed window origin
|
||||
constexpr auto idx_ys_offset = [&]() {
|
||||
constexpr auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess);
|
||||
constexpr auto adapter_ys_offset = make_tensor_adaptor_coordinate(
|
||||
StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(),
|
||||
container_concat(array<index_t, Base::NDimP>{0},
|
||||
to_array<index_t, idx_off_ys.size()>(idx_off_ys)));
|
||||
return adapter_ys_offset.get_bottom_index();
|
||||
}();
|
||||
const auto lds_ys_offset = [&]() {
|
||||
if constexpr(static_move_ys)
|
||||
{
|
||||
const auto coord_ys_offset =
|
||||
make_tensor_coordinate(tensor_descriptor, idx_ys_offset);
|
||||
return coord_ys_offset.get_offset();
|
||||
}
|
||||
else
|
||||
return 0;
|
||||
}();
|
||||
|
||||
// Use precomputed window origin & tensor descriptor
|
||||
auto lds_bottom_tensor_thread_idx =
|
||||
window_origin + window_adaptor_warp_coord.get_bottom_index();
|
||||
|
||||
// Use precomputed tensor descriptor
|
||||
const auto lds_coord =
|
||||
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
|
||||
|
||||
// Calculate SMEM address using base pointer
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
|
||||
CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
|
||||
lds_coord.get_offset() / Traits::PackedSize +
|
||||
lds_ys_offset / Traits::PackedSize;
|
||||
|
||||
const auto dram_ys_offset = [&]() {
|
||||
if constexpr(static_move_ys)
|
||||
{
|
||||
const auto coord_ys_offset = make_tensor_coordinate(
|
||||
this->get_bottom_tensor_view().get_tensor_descriptor(), idx_ys_offset);
|
||||
return coord_ys_offset.get_offset();
|
||||
}
|
||||
else
|
||||
return 0;
|
||||
}();
|
||||
|
||||
// Write into bottom tensor
|
||||
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
|
||||
smem,
|
||||
bottom_tensor_thread_coord,
|
||||
offset,
|
||||
offset + dram_ys_offset,
|
||||
bool_constant<oob_conditional_check>{});
|
||||
|
||||
// Move thread coordinate if not last access
|
||||
@@ -569,11 +619,15 @@ struct tile_window_with_static_distribution
|
||||
generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
|
||||
idx_diff_ys);
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
|
||||
if constexpr(!static_move_ys)
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_thread_coord,
|
||||
bottom_tensor_thread_coord,
|
||||
idx_diff_ps_ys);
|
||||
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
|
||||
if constexpr(!static_move_ys)
|
||||
Base::move_window_adaptor_and_bottom_tensor_thread_coordinate(
|
||||
window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -432,6 +432,12 @@ CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
|
||||
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_scaled(m, k) =
|
||||
ck_tile::type_convert<AccDataType>((a_m_k(m, k))) *
|
||||
ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ template <> struct typeToStr<fp8_t> { static constexpr const char * name = "fp8"
|
||||
template <> struct typeToStr<bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
template <> struct typeToStr<int8_t> { static constexpr const char * name = "int8"; };
|
||||
template <> struct typeToStr<pk_int4_t> { static constexpr const char * name = "pk_int4"; };
|
||||
template <> struct typeToStr<pk_fp4_t> { static constexpr const char * name = "pk_fp4"; };
|
||||
|
||||
template <memory_operation_enum MemOp> struct memOpToStr;
|
||||
template <> struct memOpToStr<memory_operation_enum::set> { static constexpr const char * name = "set"; };
|
||||
|
||||
@@ -143,16 +143,24 @@ struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, Epilog
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
constexpr index_t kKPerBlock = FlatmmPipeline::kKPerBlock;
|
||||
constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile;
|
||||
const index_t kFlatKBlocks = kargs.K / kKPerBlock;
|
||||
const index_t kFlatN = kargs.N / kNWarpTile;
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
static_assert(flatKPerBlock % FlatmmPipeline::GetVectorSizeB() == 0,
|
||||
"wrong! vector size for B tensor");
|
||||
auto&& naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(kFlatN, kFlatKBlocks, number<flatKPerBlock>{}));
|
||||
auto&& desc = transform_tensor_descriptor(
|
||||
naive_desc,
|
||||
make_tuple(make_pass_through_transform(kFlatN),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(kFlatKBlocks, number<flatKPerBlock>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global>(b_flat_ptr, desc);
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
|
||||
@@ -44,7 +44,10 @@ struct BaseFlatmmPipelineAGmemBGmemCRegV1
|
||||
else if(TailNumber::Odd == tail_num)
|
||||
return TailHandler<DispatchHotloop, TailNumber::Odd>(run_func, has_hot_loop);
|
||||
else
|
||||
{
|
||||
assert(("Wrong TailNumber!", false));
|
||||
return decltype(TailHandler<>(run_func, true, TailNumber::Even)){};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem<ADataType_,
|
||||
static constexpr int NXdlPack = 2; // it's fixed for fp4
|
||||
static constexpr int KXdlPack = 2;
|
||||
// static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack;
|
||||
static constexpr index_t flatKPerWarp = 64 * ContinuousKPerThread;
|
||||
static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = MXF4FlatmmPipelineAgBgCrPolicy>
|
||||
@@ -122,9 +122,10 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr index_t BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static constexpr index_t MXdlPack = Problem::MXdlPack;
|
||||
static constexpr index_t NXdlPack = Problem::NXdlPack;
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
static constexpr index_t MXdlPack = Problem::MXdlPack;
|
||||
static constexpr index_t NXdlPack = Problem::NXdlPack;
|
||||
static constexpr index_t KXdlPack = Problem::KXdlPack;
|
||||
static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK;
|
||||
|
||||
static constexpr index_t AK1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
|
||||
static constexpr index_t BK1 = Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize;
|
||||
@@ -138,25 +139,25 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
static constexpr index_t mfma_per_wg = 1; // 950 only
|
||||
|
||||
static constexpr index_t dsread_per_wg =
|
||||
WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize / Problem::VectorLoadSize;
|
||||
static_assert((WG::kM * WG::kK * sizeof(ADataType) / APackedSize / WaveSize) %
|
||||
Problem::VectorLoadSize ==
|
||||
0);
|
||||
static constexpr index_t dsread_per_wg = WG::kM * WG::kK / AK1 / WaveSize;
|
||||
static_assert((WG::kM * WG::kK) % (AK1 * WaveSize) == 0);
|
||||
|
||||
static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp;
|
||||
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
|
||||
static constexpr index_t dswrite_num_perK = dsread_num_perK / NWarp;
|
||||
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize;
|
||||
static constexpr index_t ScaleBload_K1 = NXdlPack * KXdlPack; // fixed for fp4
|
||||
static constexpr index_t Bload_num = Bload_num_perK * KIterPerWarp;
|
||||
static constexpr index_t ScaleBload_num =
|
||||
kNPerBlock * kKPerBlock / NWarp / 32 / ScaleBload_K1 / WaveSize;
|
||||
static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
kNPerBlock * kKPerBlock / NWarp / ScaleGranularityK / NXdlPack / KXdlPack / WaveSize;
|
||||
static constexpr index_t ScaleAload_num =
|
||||
kMPerBlock * kKPerBlock / MWarp / ScaleGranularityK / MXdlPack / KXdlPack / WaveSize;
|
||||
|
||||
// static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
|
||||
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
@@ -219,7 +220,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{
|
||||
if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
|
||||
{
|
||||
@@ -234,7 +235,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
{
|
||||
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
// __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
|
||||
{
|
||||
@@ -470,18 +471,16 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
typename ScaleBDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindowTmp a_copy_dram_window,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_window,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_window,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_ping,
|
||||
void* __restrict__ p_smem_pong) const
|
||||
{
|
||||
#ifndef __gfx950__
|
||||
static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now.");
|
||||
@@ -495,9 +494,8 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
|
||||
const index_t iMWarp = get_warp_id() / NWarp;
|
||||
// const index_t iNWarp = get_warp_id() % NWarp;
|
||||
// constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2);
|
||||
static_assert(NWarp == 4);
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
@@ -506,6 +504,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
auto a_dram_window =
|
||||
make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor<Problem>(
|
||||
a_copy_dram_window_tmp.get_bottom_tensor_view()),
|
||||
a_copy_dram_window_tmp.get_window_lengths(),
|
||||
a_copy_dram_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_ADramTileDistribution<Problem>());
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// A tile in LDS
|
||||
@@ -520,93 +525,51 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
auto a_lds_block_pong =
|
||||
make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
|
||||
|
||||
auto a_copy_lds_window_ping =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
auto a_copy_lds_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
auto a_store_lds_window_ping = make_tile_window(
|
||||
a_lds_block_ping, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
auto a_store_lds_window_pong = make_tile_window(
|
||||
a_lds_block_pong, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// ping-pong window for A LDS
|
||||
auto a_warp_window_ping_tmp =
|
||||
auto a_warp_window_ping =
|
||||
make_tile_window(a_lds_block_ping,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
|
||||
auto a_warp_window_pong_tmp =
|
||||
auto a_warp_window_pong =
|
||||
make_tile_window(a_lds_block_pong,
|
||||
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
|
||||
{iMWarp * WG::kM, 0},
|
||||
{0, 0},
|
||||
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution<Problem>());
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows_ping;
|
||||
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
auto packed_m_idx = mIter / number<MXdlPack>{};
|
||||
auto packed_m_rank = mIter % number<MXdlPack>{};
|
||||
|
||||
move_tile_window(
|
||||
a_warp_windows_ping(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
|
||||
kIter * KPerBlockPerIter});
|
||||
move_tile_window(
|
||||
a_warp_windows_pong(mIter)(kIter),
|
||||
{packed_m_idx * MXdlPack * MPerBlockPerIter + packed_m_rank * WG::kM,
|
||||
kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
|
||||
// Block GEMM
|
||||
auto block_flatmm = BlockFlatmm();
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_flatmm.MakeCBlockTile();
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>();
|
||||
|
||||
auto b_flat_dram_window = make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
using MXFP4_B_Buffer = decltype(load_tile(b_flat_dram_window));
|
||||
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
|
||||
using V4UInt_B_Buffer = thread_buffer<uint32_t, 4>;
|
||||
union UnionBuf
|
||||
{
|
||||
V4UInt_B_Buffer u = 0;
|
||||
MXFP4_B_Buffer mxfp4;
|
||||
} ub;
|
||||
|
||||
// pingpong buffer for B
|
||||
auto b_flat_dram_windows = generate_tuple(
|
||||
[&](auto nIter) {
|
||||
constexpr auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
constexpr auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
auto window_i = make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution<Problem>());
|
||||
move_tile_window(
|
||||
window_i,
|
||||
{number<packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank>{},
|
||||
number<0>{}});
|
||||
return window_i;
|
||||
},
|
||||
number<NIterPerWarp>{});
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
|
||||
statically_indexed_array<decltype(load_tile(b_flat_dram_windows(I0))), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_flat_dram_windows;
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_ping;
|
||||
statically_indexed_array<statically_indexed_array<V4UInt_B_Buffer, KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_tensor_pong;
|
||||
b_warp_tensor_ping, b_warp_tensor_pong;
|
||||
|
||||
// pingpong buffer for Scale A and Scale B
|
||||
auto scale_a_dram_window = make_tile_window(
|
||||
@@ -649,29 +612,24 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
NIterPerWarp / NXdlPack>
|
||||
scale_b_tile_tensor_pong;
|
||||
|
||||
auto async_load_tile_ = [](auto lds, auto dram) {
|
||||
async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{});
|
||||
};
|
||||
|
||||
// HEAD
|
||||
// Prefetch A0
|
||||
auto a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
async_load_tile_(a_store_lds_window_ping, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock});
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_ping(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter});
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, KIterPerWarp * KFlatPerBlockPerIter});
|
||||
|
||||
// prefetch Scale A
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
@@ -700,71 +658,40 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
// move Scale B window to next K
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// A_Lds_TileDist may differ with ADramTileDistribution
|
||||
auto a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Prefetch A1
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
if constexpr(HasHotLoop || TailNum == TailNumber::Even)
|
||||
{
|
||||
async_load_tile_(a_store_lds_window_pong, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock});
|
||||
}
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
clear_tile(c_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
using MXFP4_A_Buffer_ping =
|
||||
decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{})));
|
||||
// use v4i32 as the data type between basicblock to avoid unpack and repack operation.
|
||||
using V4UInt_A_Buffer = thread_buffer<uint32_t, 4>;
|
||||
union UnionBuf_A_ping
|
||||
{
|
||||
V4UInt_A_Buffer u = 0;
|
||||
MXFP4_A_Buffer_ping mxfp4;
|
||||
} ua_ping;
|
||||
|
||||
using MXFP4_A_Buffer_pong =
|
||||
decltype(load_tile(a_warp_windows_pong(number<0>{})(number<0>{})));
|
||||
union UnionBuf_A_pong
|
||||
{
|
||||
V4UInt_A_Buffer u = 0;
|
||||
MXFP4_A_Buffer_pong mxfp4;
|
||||
} ua_pong;
|
||||
statically_indexed_array<decltype(load_tile(a_warp_window_pong)), m_preload> a_warp_tensor;
|
||||
|
||||
// preload A00,A10... from lds
|
||||
statically_indexed_array<V4UInt_A_Buffer, m_preload> a_warp_tensor;
|
||||
|
||||
s_waitcnt_barrier</*vmcnt*/ dswrite_num_perK>();
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
|
||||
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_ping.u;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_ping, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// MAIN LOOP
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
auto main_body_implx2 = [&]() mutable {
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
move_tile_window(b_flat_dram_windows(nIter),
|
||||
{0, BlockGemmShape::flatKPerBlock});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -791,15 +718,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
@@ -807,30 +725,26 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
|
||||
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
|
||||
constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack;
|
||||
constexpr auto m_iter = mIter_pack * MXdlPack + imxdl;
|
||||
constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl;
|
||||
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
|
||||
constexpr auto n_iter = nIter_pack * NXdlPack + inxdl;
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() =
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<m_iter, n_iter>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_ping ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0],
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -838,68 +752,60 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_pack * MXdlPack + imxdl,
|
||||
nIter_pack * NXdlPack + inxdl>{},
|
||||
merge_sequences(sequence<m_iter, n_iter>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 +
|
||||
(kIter_pack * KXdlPack + ikxdl) * 2 +
|
||||
(mIter_pack * MXdlPack + imxdl) / 2 * 4 +
|
||||
m_preload;
|
||||
constexpr auto addr =
|
||||
m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload;
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
|
||||
s_waitcnt< // vmcnt
|
||||
Bload_num + ScaleAload_num + ScaleBload_num>();
|
||||
block_sync_lds();
|
||||
|
||||
// Prefetch A(2i+2)
|
||||
async_load_tile_(a_store_lds_window_ping, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock});
|
||||
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// preload A(2i+1)
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_pong, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
|
||||
});
|
||||
HotLoopScheduler();
|
||||
|
||||
// Next K
|
||||
////////////////////////////// Next K //////////////////////////////
|
||||
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_ping(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter), number<kIter * KFlatPerBlockPerIter>{});
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
move_tile_window(b_flat_dram_windows(nIter),
|
||||
{0, BlockGemmShape::flatKPerBlock});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -926,15 +832,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_ping, a_block_tile_transformed);
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
a_block_tile = load_tile(a_copy_dram_window);
|
||||
// move A window to next k
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
@@ -953,20 +850,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_pong ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
@@ -988,39 +878,47 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_pong,
|
||||
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
// barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished
|
||||
s_waitcnt< // vmcnt
|
||||
Bload_num + ScaleAload_num + ScaleBload_num>();
|
||||
block_sync_lds();
|
||||
|
||||
// Prefetch A(2i+3)
|
||||
async_load_tile_(a_store_lds_window_pong, a_dram_window);
|
||||
move_tile_window(a_dram_window, {0, kKPerBlock});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// preload A(2i+2)
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
ua_ping.mxfp4 = load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_ping.u; // reload a_warp_tensor with ping buffer
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_ping, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
|
||||
});
|
||||
HotLoopScheduler();
|
||||
};
|
||||
|
||||
iCounter--;
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t iCounter = (num_loop - 1) / 2;
|
||||
do
|
||||
{
|
||||
main_body_implx2();
|
||||
iCounter--;
|
||||
} while(iCounter > 0);
|
||||
}
|
||||
|
||||
// TAIL
|
||||
@@ -1029,18 +927,9 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
auto packed_n_idx = nIter / number<NXdlPack>{};
|
||||
auto packed_n_rank = nIter % number<NXdlPack>{};
|
||||
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{packed_n_idx * NXdlPack * NFlatPerBlockPerIter + packed_n_rank,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
|
||||
ub.mxfp4 = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
b_warp_tensor_pong(nIter)(kIter) = ub.u;
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_windows(nIter),
|
||||
make_tuple(number<0>{}, number<kIter * KFlatPerBlockPerIter>{}));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1055,7 +944,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack));
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) {
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window;
|
||||
@@ -1067,10 +955,6 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
a_block_tile_transformed = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_transformed);
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) {
|
||||
static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) {
|
||||
@@ -1089,20 +973,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_ping ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -1124,30 +1001,28 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
// barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished
|
||||
s_waitcnt< // vmcnt
|
||||
Bload_num + ScaleAload_num + ScaleBload_num>();
|
||||
block_sync_lds();
|
||||
|
||||
// preload A(2i+1)
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
ua_pong.mxfp4 = load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
|
||||
a_warp_tensor(loadIter) = ua_pong.u; // reload a_warp_tensor with pong buffer
|
||||
constexpr auto mIter = loadIter % MXdlPack;
|
||||
constexpr auto kIter = loadIter / MXdlPack;
|
||||
a_warp_tensor(loadIter) = load_tile_with_offset(
|
||||
a_warp_window_pong, tuple<number<mIter * WG::kM>, number<kIter * WG::kK>>{});
|
||||
});
|
||||
|
||||
Last2ndHotLoopScheduler();
|
||||
@@ -1170,19 +1045,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_pong ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_pong(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_pong(nIter_pack)(kIter_pack)
|
||||
@@ -1204,18 +1073,11 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_pong.mxfp4 = load_tile(
|
||||
a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_pong.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_pong,
|
||||
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1244,20 +1106,13 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
UnionBuf_A_ping ua_compute;
|
||||
ua_compute.u = a_warp_tensor(number<AwarpIter>{});
|
||||
|
||||
UnionBuf ub_compute;
|
||||
ub_compute.u =
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl);
|
||||
|
||||
// warp GEMM
|
||||
WG{}.template
|
||||
operator()<ikxdl * MXdlPack + imxdl, ikxdl * NXdlPack + inxdl>(
|
||||
c_warp_tensor,
|
||||
ua_compute.mxfp4,
|
||||
ub_compute.mxfp4,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter_pack * number<NXdlPack>{} + inxdl)(
|
||||
kIter_pack * number<KXdlPack>{} + ikxdl),
|
||||
scale_a_tile_tensor_ping(mIter_pack)(kIter_pack)
|
||||
.get_thread_buffer()[0], // scale A
|
||||
scale_b_tile_tensor_ping(nIter_pack)(kIter_pack)
|
||||
@@ -1279,18 +1134,11 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
if constexpr(addr < (KIterPerWarp * MIterPerWarp) &&
|
||||
(nIter_pack == NIterPerWarp / NXdlPack - 1))
|
||||
{
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
ua_ping.mxfp4 = load_tile(
|
||||
a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
a_warp_tensor(number<AwarpIter>{}) = ua_ping.u;
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr(kIter_pack * KXdlPack + ikxdl == KIterPerWarp - 1 &&
|
||||
mIter_pack * MXdlPack + imxdl == MIter_2nd_last)
|
||||
{
|
||||
block_sync_lds();
|
||||
constexpr auto AmIter = addr % 2 + addr / 4 * 2;
|
||||
constexpr auto AkIter = addr / 2 % 2;
|
||||
a_warp_tensor(number<AwarpIter>{}) = load_tile_with_offset(
|
||||
a_warp_window_ping,
|
||||
tuple<number<AmIter * WG::kM>, number<AkIter * WG::kK>>{});
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1299,32 +1147,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Probl
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
|
||||
else
|
||||
{
|
||||
static_assert(false, "Wrong TailNum");
|
||||
}
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename ScaleADramBlockWindowTmp,
|
||||
typename ScaleBDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const ScaleADramBlockWindowTmp& scale_a_flat_window_tmp,
|
||||
const ScaleBDramBlockWindowTmp& scale_b_flat_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType & a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
scale_a_flat_window_tmp,
|
||||
scale_b_flat_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -13,22 +13,139 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
static constexpr index_t KBPerLoad = 32;
|
||||
static constexpr index_t kDramLoadPackBytes = 128;
|
||||
|
||||
static constexpr int MXdlPack = 2;
|
||||
static constexpr int NXdlPack = 2;
|
||||
static constexpr int KXdlPack = 2;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
static inline constexpr auto wg_attr_num_access =
|
||||
std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
|
||||
? WGAttrNumAccessEnum::Single
|
||||
: WGAttrNumAccessEnum::Double;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
static_assert(
|
||||
sizeof(ADataType) * numeric_traits<BDataType>::PackedSize ==
|
||||
sizeof(BDataType) * numeric_traits<ADataType>::PackedSize,
|
||||
"sizeof(ADataType) / APackedSize must be equal to sizeof(BDataType) / BPackedSize!");
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher< //
|
||||
ADataType,
|
||||
BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access<Problem>>;
|
||||
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
|
||||
ADataType,
|
||||
BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem, typename TensorView>
|
||||
CK_TILE_DEVICE static constexpr auto
|
||||
MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
const auto& naive_desc = naive_view.get_tensor_descriptor();
|
||||
constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
|
||||
static_assert(ndims == 2, "only support 2D tensor");
|
||||
const auto rows = naive_desc.get_length(number<0>{});
|
||||
const auto cols = naive_desc.get_length(number<1>{});
|
||||
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
const index_t K0 = cols / (K1 * K2);
|
||||
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
|
||||
|
||||
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
|
||||
const index_t M0 = rows / M1;
|
||||
const auto row_lens = make_tuple(M0, number<M1>{});
|
||||
|
||||
const auto desc_0 =
|
||||
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
|
||||
const auto desc_1 = transform_tensor_descriptor(
|
||||
desc_0,
|
||||
make_tuple(make_pass_through_transform(M0),
|
||||
make_xor_transform(make_tuple(number<M1>{}, number<K1>{})),
|
||||
make_pass_through_transform(K0),
|
||||
make_pass_through_transform(number<K2>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
|
||||
const auto desc = transform_tensor_descriptor( //
|
||||
desc_1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(row_lens),
|
||||
make_merge_transform_v3_division_mod(col_lens)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
// printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));
|
||||
|
||||
return tensor_view<typename TensorView::buffer_view,
|
||||
remove_cvref_t<decltype(desc)>,
|
||||
TensorView::DstInMemOp>{naive_view.buf_, desc};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
|
||||
{
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
|
||||
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
|
||||
constexpr index_t M2 = get_warp_size() / K1; // 8
|
||||
constexpr index_t M1 = BlockSize / get_warp_size(); // 4
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
|
||||
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding< //
|
||||
sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
|
||||
tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>, // M0,K0,K2
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
@@ -36,65 +153,70 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
|
||||
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
|
||||
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
|
||||
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
|
||||
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
constexpr index_t M3 = 4; // so that we can use imm offset to load lds
|
||||
constexpr index_t M2 = get_warp_size() / K1 / M3; // 2
|
||||
constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
|
||||
constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
|
||||
static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");
|
||||
|
||||
constexpr index_t Pad = 4 * K2; // 4 * 32
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
|
||||
make_tuple(number<M0>{},
|
||||
number<M1>{},
|
||||
number<K0>{},
|
||||
number<M2>{},
|
||||
number<M3>{},
|
||||
number<K1>{},
|
||||
number<K2>{}),
|
||||
make_tuple(number<M1*(K0 * (M2 * M3 * K1 * K2) + (K0 - 1) * Pad)>{},
|
||||
number<K0*(M2 * M3 * K1 * K2) + (K0 - 1) * Pad>{},
|
||||
number<M2 * M3 * K1 * K2 + Pad>{},
|
||||
number<M3 * K1 * K2>{},
|
||||
number<K1 * K2>{},
|
||||
number<K2>{},
|
||||
number<1>{}),
|
||||
number<K2>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
constexpr auto a_lds_block_desc_1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
make_tuple(make_pass_through_transform(M0),
|
||||
make_pass_through_transform(M1),
|
||||
make_pass_through_transform(K0),
|
||||
make_pass_through_transform(M2),
|
||||
make_xor_transform(make_tuple(number<M3>{}, number<K1>{})),
|
||||
make_pass_through_transform(number<K2>{})),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{}));
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
a_lds_block_desc_1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<M0>{}, number<M1>{}, number<M2>{}, number<M3>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
|
||||
make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// return a_lds_block_desc_permuted;
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
|
||||
{
|
||||
@@ -105,20 +227,31 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
|
||||
constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
|
||||
constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M_Lane = TileShape::WarpTile::at(I0);
|
||||
constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16
|
||||
|
||||
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I0); // 4
|
||||
constexpr int K_Lane = 64 / M_Lane; // 4
|
||||
|
||||
constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 32
|
||||
constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32
|
||||
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
|
||||
constexpr int K1 = K_Thread / num_access_v; // 16
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<N_warps>,
|
||||
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
std::conditional_t<
|
||||
num_access_v == 1,
|
||||
tile_distribution_encoding<
|
||||
sequence<N_warps>,
|
||||
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
sequence<2>,
|
||||
sequence<1>>,
|
||||
tile_distribution_encoding< //
|
||||
sequence<N_warps>,
|
||||
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<num_access_v, K_Lane, K1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 2>>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -132,25 +265,36 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t K1 = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t K0 = KWavePerBlk;
|
||||
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
constexpr index_t kKPerThread = 32;
|
||||
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
|
||||
constexpr index_t K2 = kKPerThread / num_access_v;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>,
|
||||
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
std::conditional_t< //
|
||||
num_access_v == 1,
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
|
||||
sequence<K0, K1, K2>>, // 1 64 32
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>,
|
||||
sequence<2>,
|
||||
sequence<2>>,
|
||||
tile_distribution_encoding< //
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
|
||||
sequence<num_access_v, K0, K1, K2>>, // 2 1 64 16
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>,
|
||||
tuple<sequence<0, 0, 1>, sequence<2>>,
|
||||
sequence<2, 2>,
|
||||
sequence<0, 3>>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -270,6 +414,21 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
return sizeof(ADataType) *
|
||||
MakeMXFP4_ALdsBlockDescriptor<Problem>().get_element_space_size() / APackedSize;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return GetSmemSizeA<Problem>();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user