mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Merge commit 'e135dd518d19a36466ce7c61bb9d3203ec18c8af' into develop
This commit is contained in:
@@ -14,6 +14,7 @@ if(has_supported_gpu)
|
||||
add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
|
||||
add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
|
||||
add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp)
|
||||
add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp) # TODO: 950 only
|
||||
|
||||
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
|
||||
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
|
||||
@@ -27,6 +28,6 @@ if(has_supported_gpu)
|
||||
target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
target_compile_options(tile_example_mx_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS}) # TODO: 950 only
|
||||
endif()
|
||||
|
||||
|
||||
506
example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
Normal file
506
example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
Normal file
@@ -0,0 +1,506 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mx_flatmm.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise>
|
||||
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
persistent,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>;
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
|
||||
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::MXFlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using CodegenMXFlatmmPipeline =
|
||||
ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using Kernel =
|
||||
ck_tile::MXFlatmmKernel<TilePartitioner, CodegenMXFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenMXFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ScaleA,
|
||||
typename ScaleB,
|
||||
bool UsePersistentKernel = false,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::DeviceMem& b_shuffle_dev_buf,
|
||||
ck_tile::DeviceMem& c_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
ScaleA scale_a,
|
||||
ScaleB scale_b,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::ScaleFlatmmHostArgs<ScaleA, ScaleB> args = {a_dev_buf.GetDeviceBuffer(),
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C,
|
||||
scale_a,
|
||||
scale_b};
|
||||
|
||||
float ave_time = mx_flatmm_calc<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ScaleA,
|
||||
ScaleB,
|
||||
UsePersistentKernel,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
constexpr int APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr int BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / 32;
|
||||
std::size_t num_byte = sizeof(ADataType) * M * K / APackedSize +
|
||||
sizeof(BDataType) * N * K / BPackedSize + sizeof(CDataType) * M * N +
|
||||
sizeof(ck_tile::e8m0_t) * M * K / 32 +
|
||||
sizeof(ck_tile::e8m0_t) * N * K / 32;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run MXFP4_Flatmm kernel " //
|
||||
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
|
||||
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "32", "m dimension")
|
||||
.insert("n", "128", "n dimension")
|
||||
.insert("k", "256", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert(
|
||||
"mx_prec", "fp4xfp4", "data type for activation and weight, support: fp6xfp6, fp8xfp8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, class IterSrc, class IterDst>
|
||||
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = FlatmmConfig::N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, bool KLast, typename Src>
|
||||
auto preShuffleScale(Src& src)
|
||||
{
|
||||
using dtype = typename Src::Data::value_type;
|
||||
auto src_lengths = src.get_lengths();
|
||||
const auto MN = KLast ? src_lengths[0] : src_lengths[1];
|
||||
const auto K = KLast ? src_lengths[1] : src_lengths[0];
|
||||
|
||||
size_t MNXdlPack = 2;
|
||||
size_t KXdlPack = 2;
|
||||
size_t XdlMNThread = FlatmmConfig::N_Warp_Tile; // 16
|
||||
size_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
const auto MN_Paded = ck_tile::integer_least_multiple(MN, XdlMNThread * MNXdlPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({MN_Paded * K}, {1}));
|
||||
|
||||
size_t K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(size_t n = 0; n < MN_Paded; ++n)
|
||||
{
|
||||
for(size_t k = 0; k < K; ++k)
|
||||
{
|
||||
auto n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
auto tempn = n % (XdlMNThread * MNXdlPack);
|
||||
auto n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
auto n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
auto k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
auto tempk = k % (XdlKThread * KXdlPack);
|
||||
auto k1 = tempk % XdlKThread; // i XdlKThread
|
||||
auto k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
auto outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
|
||||
if constexpr(KLast)
|
||||
shuffled(outputIndex) = n < MN ? src(n, k) : dtype{};
|
||||
else
|
||||
shuffled(outputIndex) = n < MN ? src(k, n) : dtype{};
|
||||
}
|
||||
}
|
||||
return shuffled;
|
||||
}
|
||||
|
||||
#include "run_mx_flatmm.inc"
|
||||
|
||||
template <typename FlatmmConfig>
|
||||
int run_mx_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string mx_prec = arg_parser.get_str("mx_prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
int persistent_opt = arg_parser.get_int("persistent");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(mx_prec == "fp4xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
FlatmmConfig,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
FlatmmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else if(mx_prec == "fp6xfp6")
|
||||
{
|
||||
throw std::runtime_error("Only support fp4xfp4 now!");
|
||||
}
|
||||
else if(mx_prec == "fp8xfp8")
|
||||
{
|
||||
throw std::runtime_error("Only support fp4xfp4 now!");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_mx_flatmm_example<MXfp4_FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
throw std::runtime_error("Only support MFMA_16x16x128 now!");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported warp_tile!");
|
||||
}
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
15
example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp
Normal file
15
example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp
Normal file
@@ -0,0 +1,15 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include "mxfp4_flatmm.hpp"
|
||||
40
example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp
Normal file
40
example/ck_tile/18_flatmm/mxgemm/mxfp4_flatmm.hpp
Normal file
@@ -0,0 +1,40 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 512;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
167
example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
Normal file
167
example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
Normal file
@@ -0,0 +1,167 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
template <typename PrecActType,
|
||||
typename PrecWeightType,
|
||||
typename CDataType,
|
||||
typename FlatmmConfig,
|
||||
bool UsePersistentKernel = false,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_mx_flatmm_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using AccDataType = float;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
constexpr int ScaleGranularityM = 1;
|
||||
constexpr int ScaleGranularityN = 1;
|
||||
constexpr int ScaleGranularityK = 32;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
ck_tile::index_t n_warmup = arg_parser.get_int("warmup");
|
||||
ck_tile::index_t n_repeat = arg_parser.get_int("repeat");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
|
||||
|
||||
auto scale_stride_A = ck_tile::get_default_stride(
|
||||
M / ScaleGranularityM, K / ScaleGranularityK, 0, is_row_major(a_layout));
|
||||
auto scale_stride_B = ck_tile::get_default_stride(
|
||||
K / ScaleGranularityK, N / ScaleGranularityN, 0, is_row_major(b_layout));
|
||||
|
||||
if(K % ScaleGranularityK != 0)
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleGranularityK.");
|
||||
if(K % ck_tile::numeric_traits<ADataType>::PackedSize != 0 ||
|
||||
K % ck_tile::numeric_traits<BDataType>::PackedSize != 0)
|
||||
throw std::runtime_error("wrong! K must be multiple of packed size.");
|
||||
|
||||
ck_tile ::HostTensor<ADataType> a_host(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_origin_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_rslt_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_a(ck_tile::host_tensor_descriptor(
|
||||
M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::host_tensor_descriptor(
|
||||
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! Unexpected init_method");
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffled_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffled_host.begin(), N, K);
|
||||
|
||||
const auto scale_a_shuffled = preShuffleScale<FlatmmConfig, true>(scale_a);
|
||||
const auto scale_b_shuffled = preShuffleScale<FlatmmConfig, false>(scale_b);
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes());
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_shuffled_dev_buf.ToDevice(b_shuffled_host.data());
|
||||
c_rslt_host.SetZero();
|
||||
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
|
||||
|
||||
auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
|
||||
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
|
||||
invoke_mx_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
decltype(scale_a_dev_ptr),
|
||||
decltype(scale_b_dev_ptr),
|
||||
UsePersistentKernel>(a_dev_buf,
|
||||
b_shuffled_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
scale_a_dev_ptr,
|
||||
scale_b_dev_ptr,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
|
||||
bool pass = true;
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
|
||||
a_host, b_origin_host, c_m_n_host_ref, scale_a, scale_b);
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_rslt_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -382,6 +382,93 @@ reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename ScaleDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename AElementOp = ck_tile::identity,
|
||||
typename BElementOp = ck_tile::identity,
|
||||
typename ACCElementOp = ck_tile::identity>
|
||||
CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const HostTensor<BDataType>& b_k_n,
|
||||
HostTensor<CDataType>& c_m_n,
|
||||
const HostTensor<ScaleDataType>& scale_a,
|
||||
const HostTensor<ScaleDataType>& scale_b,
|
||||
const AElementOp& = {},
|
||||
const BElementOp& = {},
|
||||
const ACCElementOp& = {})
|
||||
{
|
||||
static_assert(std::is_same_v<AElementOp, ck_tile::identity>);
|
||||
static_assert(std::is_same_v<BElementOp, ck_tile::identity>);
|
||||
static_assert(std::is_same_v<ACCElementOp, ck_tile::identity>);
|
||||
|
||||
const std::size_t M = a_m_k.get_length(0);
|
||||
const std::size_t N = b_k_n.get_length(1);
|
||||
const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
const std::size_t ScaleBlockSize = K / scale_a.get_length(1);
|
||||
|
||||
HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
|
||||
{std::size_t(K), std::size_t(1)});
|
||||
HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
|
||||
{std::size_t(1), std::size_t(K)});
|
||||
|
||||
for(std::size_t m = 0; m < M; ++m)
|
||||
{
|
||||
for(std::size_t k = 0; k < K; ++k)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
if(k % 2 == 1)
|
||||
continue; // skip odd k
|
||||
|
||||
auto a_f4x2 = a_m_k(m, k);
|
||||
auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
|
||||
auto a_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
|
||||
auto a_f4_hi =
|
||||
ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));
|
||||
|
||||
a_m_k_scaled(m, k) = a_f4_lo * a_scale;
|
||||
a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for(std::size_t n = 0; n < N; n++)
|
||||
{
|
||||
for(std::size_t k = 0; k < K; k++)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
if(k % 2 == 1)
|
||||
continue; // skip odd k
|
||||
|
||||
auto b_f4x2 = b_k_n(k, n);
|
||||
auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
|
||||
auto b_f4_lo =
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
|
||||
auto b_f4_hi =
|
||||
ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));
|
||||
|
||||
b_k_n_scaled(k, n) = b_f4_lo * b_scale;
|
||||
b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
|
||||
ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// call reference gemm
|
||||
reference_gemm<AccDataType, AccDataType, AccDataType, CDataType>(
|
||||
a_m_k_scaled, b_k_n_scaled, c_m_n);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
|
||||
@@ -13,11 +13,14 @@
|
||||
#include "ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp"
|
||||
#include "ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp"
|
||||
#include "ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp"
|
||||
#include "ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
|
||||
@@ -902,8 +902,8 @@ struct FlatmmKernel
|
||||
{
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
|
||||
518
include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp
Normal file
518
include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp
Normal file
@@ -0,0 +1,518 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
|
||||
struct MXFlatmmKernel : FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>
|
||||
{
|
||||
using Underlying = FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>;
|
||||
using BlockGemmShape =
|
||||
remove_cvref_t<typename MXFlatmmPipeline_::BlockGemmShape>; // TileFlatmmShape
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
|
||||
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
|
||||
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
static constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{});
|
||||
static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{});
|
||||
static constexpr int KThreadPerXdl = 64 / MThreadPerXdl;
|
||||
|
||||
static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
static constexpr int MXdlPack = FlatmmPipeline::MXdlPack;
|
||||
static constexpr int NXdlPack = FlatmmPipeline::NXdlPack;
|
||||
static constexpr int KXdlPack = FlatmmPipeline::KXdlPack;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto I3 = number<3>();
|
||||
static constexpr auto I4 = number<4>();
|
||||
static constexpr auto I5 = number<5>();
|
||||
|
||||
static_assert(DsLayout::size() == DsDataType::size(),
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "mx_flatmm_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
|
||||
{
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
hipDeviceProp_t prop;
|
||||
int deviceId = 0; // default device
|
||||
|
||||
constexpr int block_size = MXFlatmmKernel::BlockSize().x;
|
||||
int dync_smem_size = 0;
|
||||
int maxActiveBlocksPerCU = 0;
|
||||
|
||||
if(hipGetDeviceProperties(&prop, deviceId) != hipSuccess)
|
||||
throw std::runtime_error(std::string("hipGetDeviceProperties failed: ") +
|
||||
hipGetErrorName(hipGetLastError()));
|
||||
|
||||
if(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&maxActiveBlocksPerCU,
|
||||
reinterpret_cast<void*>(
|
||||
kentry<1, MXFlatmmKernel, remove_cvref_t<decltype(kargs)>>),
|
||||
block_size,
|
||||
dync_smem_size) != hipSuccess)
|
||||
throw std::runtime_error(
|
||||
std::string("hipOccupancyMaxActiveBlocksPerMultiprocessor failed: ") +
|
||||
hipGetErrorName(hipGetLastError()));
|
||||
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
// << ", persistent_block_size: " << persistent_block_size
|
||||
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
|
||||
|
||||
if(kargs.k_batch != 1)
|
||||
throw std::runtime_error("Wrong! k_batch != 1 not supported in persistent kernel");
|
||||
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
|
||||
}
|
||||
}
|
||||
|
||||
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
||||
make_tuple(kargs.stride_A, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
make_tuple(kFlatN, kFlatK),
|
||||
make_tuple(kFlatK, 1),
|
||||
number<FlatmmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DDataType_*>(ds_ptr[i]),
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_Ds[i], 1),
|
||||
number<EpiloguePipeline::GetVectorSizeD(i)>{},
|
||||
number<1>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& e_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.N, kargs.M),
|
||||
make_tuple(kargs.stride_E, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto scale_a = kargs.scale_m_ptr;
|
||||
auto scale_b = kargs.scale_n_ptr;
|
||||
|
||||
static constexpr int BlockScaleSize = 32; // decltype(scale_n)::GranularityK;
|
||||
const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl));
|
||||
const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl));
|
||||
const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl);
|
||||
|
||||
// A scale tensor view
|
||||
const auto& scale_a_tensor_view = [&]() {
|
||||
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
|
||||
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl));
|
||||
const auto scale_a_desc = transform_tensor_descriptor(
|
||||
scale_a_naive_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_a.ptr), scale_a_desc);
|
||||
}();
|
||||
|
||||
// B scale tensor view
|
||||
const auto& scale_b_tensor_view = [&]() {
|
||||
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl));
|
||||
const auto scale_b_desc = transform_tensor_descriptor(
|
||||
scale_b_navie_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)),
|
||||
make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const int32_t*>(scale_b.ptr), scale_b_desc);
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view,
|
||||
b_flat_tensor_view,
|
||||
ds_tensor_view,
|
||||
e_tensor_view,
|
||||
scale_a_tensor_view,
|
||||
scale_b_tensor_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
|
||||
{
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_tensor_view = views.at(I1);
|
||||
|
||||
const auto& ds_pad_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto& d_tensor_view = views.at(I2);
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(d_tensor_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadM>{});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& e_pad_view = [&]() {
|
||||
const auto& e_tensor_view = views.at(I3);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, FlatmmPipeline::kPadN>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<FlatmmPipeline::kPadM, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(
|
||||
a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4), views.at(I5));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& b_flat_pad_view = views.at(I1);
|
||||
const auto& ds_pad_view = views.at(I2);
|
||||
const auto& e_pad_view = views.at(I3);
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{0, i_m});
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
|
||||
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(ds_pad_view[i],
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_n, i_m});
|
||||
}
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
static constexpr int BlockScaleSize = 32;
|
||||
|
||||
auto scale_a_block_window = make_tile_window(
|
||||
views.at(I4),
|
||||
make_tuple(number<TilePartitioner::MPerBlock / MXdlPack>{},
|
||||
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
|
||||
{i_m / MXdlPack, 0});
|
||||
|
||||
auto scale_b_block_window = make_tile_window(
|
||||
views.at(I5),
|
||||
make_tuple(number<TilePartitioner::NPerBlock / NXdlPack>{},
|
||||
number<TilePartitioner::KPerBlock / (BlockScaleSize * KXdlPack)>{}),
|
||||
{i_n / NXdlPack, 0});
|
||||
|
||||
return make_tuple(a_block_window,
|
||||
b_flat_block_window,
|
||||
ds_block_window,
|
||||
e_block_window,
|
||||
scale_a_block_window,
|
||||
scale_b_block_window);
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void
|
||||
RunFlatmm(const ADataType* a_ptr,
|
||||
const BDataType* b_flat_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_ping,
|
||||
void* smem_ptr_pong,
|
||||
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& scale_a_block_window = gemm_tile_windows.at(I4);
|
||||
const auto& scale_b_block_window = gemm_tile_windows.at(I5);
|
||||
|
||||
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|
||||
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|
||||
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
|
||||
"ScaleM and ScaleN should have the same GranularityK");
|
||||
constexpr bool DoEpiScale =
|
||||
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
|
||||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
|
||||
|
||||
auto a_block_window_with_distr =
|
||||
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
|
||||
a_block_window.get_window_lengths(),
|
||||
a_block_window.get_window_origin(),
|
||||
FlatmmPipeline::GetADramTileDistribution());
|
||||
const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
|
||||
b_flat_block_window,
|
||||
scale_a_block_window,
|
||||
scale_b_block_window,
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(DoEpiScale)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
}
|
||||
|
||||
template <class ScaleM, class ScaleN>
|
||||
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
|
||||
int partition_idx = blockIdx.x) const
|
||||
{
|
||||
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
|
||||
do
|
||||
{
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const SplitKBatchOffset splitk_batch_offset(kargs);
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) +
|
||||
splitk_batch_offset.a_k_split_offset / APackedSize;
|
||||
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
splitk_batch_offset.b_k_split_offset / BPackedSize;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
|
||||
__shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
|
||||
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
|
||||
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
|
||||
b_flat_ptr,
|
||||
kargs.ds_ptr,
|
||||
e_ptr,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false,
|
||||
"Unimplemented: atomic_add with odd vector size for fp16/bf16");
|
||||
}
|
||||
partition_idx += gridDim.x;
|
||||
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -291,10 +291,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
@@ -331,7 +333,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType) * APackedSize;
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
// coalesce reading for each blocks
|
||||
if constexpr(get_warp_size() % K0 == 0)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,275 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
static constexpr index_t KBPerLoad = 32;
|
||||
|
||||
static constexpr int MXdlPack = 2;
|
||||
static constexpr int NXdlPack = 2;
|
||||
static constexpr int KXdlPack = 2;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
static_assert(MPerXdl == 16 && NPerXdl == 16);
|
||||
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
|
||||
|
||||
/*reduce transform layers,compare with old ck*/
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>() * APackedSize;
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<MPerBlock>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// return a_lds_block_desc_permuted;
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
|
||||
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
|
||||
|
||||
constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
|
||||
constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr int M_Lane = TileShape::WarpTile::at(I0);
|
||||
|
||||
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I0); // 4
|
||||
|
||||
constexpr int K1 = TileShape::WarpTile::at(I2) / K_Lane; // 32
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<N_warps>,
|
||||
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 0>, sequence<0, 2>>,
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat>,
|
||||
tuple<sequence<NWavePerBlk, NXdlPack>,
|
||||
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
|
||||
tuple<sequence<0, 0, 0>, sequence<1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0);
|
||||
|
||||
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
|
||||
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
|
||||
|
||||
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
|
||||
constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
|
||||
constexpr index_t K_Lanes = 64 / M_Lanes;
|
||||
|
||||
// Y dimension (M) decomposition
|
||||
constexpr index_t Y2 = M_Lanes;
|
||||
constexpr index_t Y1 = M_Warps;
|
||||
constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2);
|
||||
|
||||
// X dimension (K) decomposition
|
||||
constexpr index_t X0 = K_Lanes;
|
||||
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<N_Warps>, // repeat N_warps
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1);
|
||||
|
||||
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
|
||||
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
|
||||
|
||||
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
|
||||
|
||||
constexpr index_t N_Lanes = TileShape::WarpTile::at(I1);
|
||||
constexpr index_t K_Lanes = 64 / N_Lanes;
|
||||
|
||||
// Y dimension (M) decomposition
|
||||
constexpr index_t Y2 = N_Lanes;
|
||||
constexpr index_t Y1 = N_Warps;
|
||||
constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2);
|
||||
|
||||
// X dimension (K) decomposition
|
||||
constexpr index_t X0 = K_Lanes;
|
||||
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<M_Warps>, // ?
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{});
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0);
|
||||
constexpr index_t M_Lane = TileShape::WarpTile::at(I0);
|
||||
constexpr index_t N_Wrap = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr index_t MWavePerBlk = M_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<N_Wrap>, // ?
|
||||
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
|
||||
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
|
||||
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
|
||||
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
|
||||
constexpr index_t M_Wrap = TileShape::BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWavePerBlk = N_Warp;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<M_Wrap>, // ?
|
||||
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
|
||||
sequence<K_Lane, 1>>, // first direction
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
|
||||
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<2>,
|
||||
sequence<1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -148,6 +148,9 @@ template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16,
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
|
||||
template<> struct WarpGemmDispatcher<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_fp4<WGAttrNumAccessEnum::Quad>; };
|
||||
|
||||
//WMMA cases
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_f8_f8<TransposeC>; };
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_bf8_bf8<TransposeC>; };
|
||||
|
||||
Reference in New Issue
Block a user