This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,278 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <type_traits>
#include "ck_tile/host.hpp"
#include "mx_flatmm.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename MXFlatmmArchTraits,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleA,
typename ScaleB,
bool UsePersistentKernel = false,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ScaleA scale_a,
ScaleB scale_b,
int n_warmup,
int n_repeat)
{
using FlatmmConfig = typename MXFlatmmArchTraits::Config;
ck_tile::ScaleFlatmmHostArgs<ScaleA, ScaleB> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
1,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C,
scale_a,
scale_b};
using FlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
CLayout,
FlatmmConfig::NumWaveGroups>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, FlatmmShape, Traits>;
using BaseFlatmmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
const ck_tile::index_t k_grain = FlatmmConfig::K_Tile;
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * k_grain;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split);
const bool has_hot_loop = BaseFlatmmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseFlatmmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time = BaseFlatmmPipeline::template TailHandler<true>(
[&](auto has_hot_loop_, auto tail_num_) {
constexpr auto has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_num_v = tail_num_.value;
return mx_flatmm_calc<MXFlatmmArchTraits,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
ScaleA,
ScaleB,
UsePersistentKernel,
CDEElementWise,
false,
has_hot_loop_v,
tail_num_v>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
},
has_hot_loop,
tail_num);
constexpr int APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr int BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / 32;
std::size_t num_byte = sizeof(ADataType) * M * K / APackedSize +
sizeof(BDataType) * N * K / BPackedSize + sizeof(CDataType) * M * N +
sizeof(ck_tile::e8m0_t) * M * K / 32 +
sizeof(ck_tile::e8m0_t) * N * K / 32;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run " << ck_tile::gemm_prec_str<ADataType, BDataType>() << " Flatmm kernel " //
<< " M = " << M << " N = " << N << " K = " << K << " StrideA = " << stride_A
<< " StrideB = " << stride_B << " StrideC = " << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "32", "m dimension")
.insert("n", "512", "n dimension")
.insert("k", "256", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("mx_prec",
"fp4xfp4",
"data type for activation and weight, support: fp4xfp4, fp6xfp6, fp8xfp8, fp8xfp4 "
"and fp4xfp8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("init", "0", "0:random, 1:constant(1)")
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
.insert("warp_tile", "0", "0: 16x16x128 on gfx950.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
#include "run_mx_flatmm.inc"
int run_mx_flatmm_example(const ck_tile::ArgParser& arg_parser)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string mx_prec = arg_parser.get_str("mx_prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
int persistent_opt = arg_parser.get_int("persistent");
std::cout << "Using default warptile of 16x16x128." << std::endl;
if(a_layout == "R" && b_layout == "C")
{
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
MXFlatmm_GFX950_FP4FP4_Traits,
false>(arg_parser, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only non-persistent kernels are supported currently!");
}
else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp6x16_t,
ck_tile::pk_fp6x16_t,
ck_tile::fp16_t,
MXFlatmm_GFX950_FP6FP6_Traits,
false>(arg_parser, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::fp16_t,
MXFlatmm_GFX950_FP8FP8_Traits,
false>(arg_parser, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp8xfp4")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
MXFlatmm_GFX950_FP8FP4_Traits,
false>(arg_parser, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else if(mx_prec == "fp4xfp8")
{
if(persistent_opt == 0)
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::fp8_t,
ck_tile::fp16_t,
MXFlatmm_GFX950_FP4FP8_Traits,
false>(arg_parser, Row{}, Col{}, Row{});
else
throw std::runtime_error("Only support non-persistent kernel now!");
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
try
{
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return run_mx_flatmm_example(arg_parser);
}
else if(warp_tile == 1)
{
throw std::runtime_error("Only support MFMA_16x16x128 now!");
}
else
{
throw std::runtime_error("Unsupported warp_tile!");
}
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,34 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "mx_flatmm_arch_traits.hpp"
template <typename MXFlatmmArchTraits,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise,
bool Splitk,
bool HasHotLoop,
ck_tile::TailNumber TailNum>
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);

View File

@@ -0,0 +1,178 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm.hpp"
namespace ck_tile {
namespace core {
namespace arch {
// Use the amdgcn_target_id enum from arch.hpp
using TargetId = amdgcn_target_id;
} // namespace arch
} // namespace core
} // namespace ck_tile
// Base FlatmmConfig with 16x16 warp tile (for non-GFX1250)
struct MXFlatmmConfigBase16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 128;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
struct MXfp4_FlatmmConfig16 : public MXFlatmmConfigBase16
{
static constexpr ck_tile::index_t N_Tile = 512;
};
// Architecture traits for MX Flatmm - Primary template (gfx950 implementation)
template <ck_tile::core::arch::TargetId Arch, typename FlatmmConfig>
struct MXFlatmmArchTraits
{
static constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
using Config = FlatmmConfig;
template <typename MXPipelineProblem>
using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
static constexpr int GetNLane() { return Config::N_Warp_Tile; }
template <typename dtype>
static auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
{
constexpr ck_tile::index_t NLane = Config::N_Warp_Tile;
auto src_lengths = src.get_lengths();
const int K = src_lengths[0];
const int N = src_lengths[1];
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
int KPack = std::is_same_v<dtype, ck_tile::pk_fp6x16_t>
? 32
: 16 * packed_size; // fp4/fp6:32 or fp8:16
int KLane = ck_tile::get_warp_size() / NLane;
int K0 = K / (KLane * KPack);
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K; k += packed_size)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
int tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
shuffled(outputIndex) = src(k, n);
}
}
return shuffled;
}
template <bool KLast, typename dtype>
static auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
{
auto src_lengths = src.get_lengths();
const auto MN = KLast ? src_lengths[0] : src_lengths[1];
const auto K = KLast ? src_lengths[1] : src_lengths[0];
size_t MNXdlPack = 2;
size_t KXdlPack = 2;
size_t XdlMNThread = Config::N_Warp_Tile; // 16
size_t XdlKThread = ck_tile::get_warp_size() / XdlMNThread;
const auto MN_Paded = ck_tile::integer_least_multiple(MN, XdlMNThread * MNXdlPack);
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({MN_Paded * K}, {1}));
size_t K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(size_t n = 0; n < MN_Paded; ++n)
{
for(size_t k = 0; k < K; ++k)
{
auto n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
auto tempn = n % (XdlMNThread * MNXdlPack);
auto n1 = tempn % XdlMNThread; // i XdlMNThread
auto n2 = tempn / XdlMNThread; // i MNXdlPack
auto k0 = k / (XdlKThread * KXdlPack); // i KRepeat
auto tempk = k % (XdlKThread * KXdlPack);
auto k1 = tempk % XdlKThread; // i XdlKThread
auto k2 = tempk / XdlKThread; // i KXdlPack
auto outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread +
n1 * MNXdlPack * KXdlPack + k2 * MNXdlPack + n2;
if constexpr(KLast)
shuffled(outputIndex) = n < MN ? src(n, k) : dtype{};
else
shuffled(outputIndex) = n < MN ? src(k, n) : dtype{};
}
}
return shuffled;
}
};
using MXFlatmm_GFX950_FP4FP4_Traits =
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXfp4_FlatmmConfig16>;
using MXFlatmm_GFX950_FP8FP8_Traits =
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;
using MXFlatmm_GFX950_FP6FP6_Traits =
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;
using MXFlatmm_GFX950_FP8FP4_Traits =
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;
using MXFlatmm_GFX950_FP4FP8_Traits =
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;

View File

@@ -0,0 +1,42 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
function(mx_flatmm_instance_generate FILE_LIST)
set(C_DATA_TYPE FP16)
set(A_LAYOUT ROW)
set(B_LAYOUT COL)
set(C_LAYOUT ROW)
set(MXFLATMM_ARCH)
if (GPU_TARGETS MATCHES "gfx95")
list(APPEND MXFLATMM_ARCH MXFlatmm_GFX950_)
endif()
# foreach(PERSISTENT false true)
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
foreach(PERSISTENT false)
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8)
string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
list(GET DATA_TYPE_AB 0 A_DATA_TYPE)
list(GET DATA_TYPE_AB 1 B_DATA_TYPE)
foreach(ARCH ${MXFLATMM_ARCH})
set(MXFLATMM_ARCH_TRAITS "${ARCH}${A_DATA_TYPE}${B_DATA_TYPE}_Traits")
foreach(SPLIT_K false true)
foreach(HAS_HOT_LOOP false true)
foreach(TAIL_NUMBER ODD EVEN)
set(KERNEL_FILE mxgemm/instance_${ARCH}${DATA_TYPE}_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
@ONLY)
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
endforeach()
endforeach()
endforeach()
endforeach()
endforeach()
endforeach()
set(${FILE_LIST} ${${FILE_LIST}} PARENT_SCOPE)
endfunction()

View File

@@ -0,0 +1,57 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "mx_flatmm_instance.hpp"
// clang-format off
#define MXFLATMM_ARCH_TRAITS @MXFLATMM_ARCH_TRAITS@
#define A_DATA_TYPE @A_DATA_TYPE@
#define B_DATA_TYPE @B_DATA_TYPE@
#define C_DATA_TYPE @C_DATA_TYPE@
#define A_LAYOUT @A_LAYOUT@
#define B_LAYOUT @B_LAYOUT@
#define C_LAYOUT @C_LAYOUT@
#define PERSISTENT @PERSISTENT@
#define SPLIT_K @SPLIT_K@
#define HAS_HOT_LOOP @HAS_HOT_LOOP@
#define TAIL_NUMBER @TAIL_NUMBER@
// clang-format on
using FP4 = ck_tile::pk_fp4_t;
using FP8 = ck_tile::fp8_t;
using FP6 = ck_tile::pk_fp6x16_t;
using FP16 = ck_tile::fp16_t;
using BF16 = ck_tile::bf16_t;
using ROW = ck_tile::tensor_layout::gemm::RowMajor;
using COL = ck_tile::tensor_layout::gemm::ColumnMajor;
using ScaleType = ck_tile::e8m0_t;
inline constexpr auto ODD = ck_tile::TailNumber::Odd;
inline constexpr auto EVEN = ck_tile::TailNumber::Even;
inline constexpr int ScaleGranularityM = 1;
inline constexpr int ScaleGranularityN = 1;
inline constexpr int ScaleGranularityK = 32;
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>;
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>;
template float mx_flatmm_calc<MXFLATMM_ARCH_TRAITS,
A_DATA_TYPE,
B_DATA_TYPE,
/*DsDatatype*/ ck_tile::tuple<>,
/*AccDataType*/ float,
C_DATA_TYPE,
A_LAYOUT,
B_LAYOUT,
/*DsLayout*/ ck_tile::tuple<>,
C_LAYOUT,
ScaleM,
ScaleN,
PERSISTENT,
/*CDEElementWise*/ ck_tile::element_wise::PassThrough,
SPLIT_K,
HAS_HOT_LOOP,
TAIL_NUMBER>(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);

View File

@@ -0,0 +1,174 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <type_traits>
#include "ck_tile/host.hpp"
#include "mx_flatmm.hpp"
template <typename Layout>
using is_row_major_t = ck_tile::bool_constant<
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>;
template <typename MXFlatmmArchTraits,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise,
bool Splitk,
bool HasHotLoop,
ck_tile::TailNumber TailNum>
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
using FlatmmConfig = typename MXFlatmmArchTraits::Config;
using FlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using MXGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
FlatmmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
FlatmmConfig::TransposeC,
FlatmmConfig::UseStructuredSparsity,
persistent,
FlatmmConfig::NumWaveGroups,
true>;
using ComputeDataType = ADataType;
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
constexpr auto scheduler = FlatmmConfig::Scheduler;
ck_tile::ignore = Splitk;
// determined by scale shuffle pattern
constexpr int BlockedXDLN_PerWarp = MXFlatmmArchTraits::BlockedXDLN_PerWarp;
using MXPipelineProblem = ck_tile::MXFlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
FlatmmShape,
MXGemmTraits,
scheduler,
HasHotLoop,
TailNum>;
using MXFlatmmPipeline =
typename MXFlatmmArchTraits::template MXFlatmmPipeline<MXPipelineProblem>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
using GemmEpilogue =
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ComputeDataType,
ComputeDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
MXPipelineProblem::TransposeC,
FlatmmConfig::NumWaveGroups,
false, // FixedVectorSize
1, // VectorSizeC
FlatmmConfig::TiledMMAPermuteN,
BlockedXDLN_PerWarp>>;
using Kernel = ck_tile::MXFlatmmKernel<TilePartitioner, MXFlatmmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << FlatmmShape::GetName() << "\n"
<< "Shape: " << FlatmmShape::GetName() << "\n"
<< "problem: " << MXPipelineProblem::GetName() << "\n"
<< "pipeline: " << MXFlatmmPipeline::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
// Declare rotating_mem_ptr here so it stays in scope until it is needed
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
std::function<void()> preprocess;
auto clear_gemm_output = [&]() {
if(args.k_batch > 1)
hipGetErrorString(
hipMemsetAsync(args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major_t<ALayout>{}));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major_t<BLayout>{}));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem_ptr->Print();
preprocess = [&]() {
ck_tile::flush_icache();
rotating_mem_ptr->Next();
clear_gemm_output();
};
}
else
{
preprocess = clear_gemm_output;
}
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -0,0 +1,182 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
template <typename PrecActType,
typename PrecWeightType,
typename CDataType,
typename MXFlatmmArchTraits,
bool UsePersistentKernel = false,
typename ALayout,
typename BLayout,
typename CLayout>
int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser,
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
const CLayout c_layout = CLayout{})
{
using ADataType = PrecActType;
using BDataType = PrecWeightType;
using AccDataType = float;
using ScaleType = ck_tile::e8m0_t;
constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 1;
constexpr int ScaleGranularityK = 32;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t init_method = arg_parser.get_int("init");
ck_tile::index_t n_warmup = arg_parser.get_int("warmup");
ck_tile::index_t n_repeat = arg_parser.get_int("repeat");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
auto scale_stride_A = ck_tile::get_default_stride(
M / ScaleGranularityM, K / ScaleGranularityK, 0, is_row_major(a_layout));
auto scale_stride_B = ck_tile::get_default_stride(
K / ScaleGranularityK, N / ScaleGranularityN, 0, is_row_major(b_layout));
if(K % ScaleGranularityK != 0)
throw std::runtime_error("wrong! K must be multiple of ScaleGranularityK.");
if(K % ck_tile::numeric_traits<ADataType>::PackedSize != 0 ||
K % ck_tile::numeric_traits<BDataType>::PackedSize != 0)
throw std::runtime_error("wrong! K must be multiple of packed size.");
ck_tile ::HostTensor<ADataType> a_host(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_origin_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_rslt_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<ScaleType> scale_a(ck_tile::host_tensor_descriptor(
M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::host_tensor_descriptor(
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));
if constexpr(std::is_same_v<ADataType, ck_tile::pk_fp6x16_t>)
{
auto a_buffer_bytes = a_host.get_element_space_size_in_bytes();
auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes();
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b);
std::vector<int8_t> random_bufA(a_buffer_bytes);
std::vector<int8_t> random_bufB(b_buffer_bytes);
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<int> dis(1, 4);
for(size_t i = 0; i < a_buffer_bytes; ++i)
random_bufA[i] = static_cast<int8_t>(dis(gen));
for(size_t i = 0; i < b_buffer_bytes; ++i)
random_bufB[i] = static_cast<int8_t>(dis(gen));
memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes);
memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes);
}
else
{
if(init_method == 0)
{
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
}
else if(init_method == 1)
{
ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
}
else
{
throw std::runtime_error("wrong! Unexpected init_method");
}
}
const auto b_shuffled_host = MXFlatmmArchTraits::preShuffleWeight(b_origin_host);
const auto scale_a_shuffled = MXFlatmmArchTraits::template preShuffleScale<true>(scale_a);
const auto scale_b_shuffled = MXFlatmmArchTraits::template preShuffleScale<false>(scale_b);
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes());
a_dev_buf.ToDevice(a_host.data());
b_shuffled_dev_buf.ToDevice(b_shuffled_host.data());
c_rslt_host.SetZero();
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
auto scale_a_dev_ptr =
ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>{
static_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
auto scale_b_dev_ptr =
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
invoke_mx_flatmm<MXFlatmmArchTraits,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
decltype(scale_a_dev_ptr),
decltype(scale_b_dev_ptr),
UsePersistentKernel>(a_dev_buf,
b_shuffled_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
scale_a_dev_ptr,
scale_b_dev_ptr,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
a_host, b_origin_host, c_m_n_host_ref, scale_a, scale_b);
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
pass = ck_tile::check_err(
c_rslt_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass ? 0 : -1;
}