mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4821 (commit 9456e0f)
[CK TILE] Refactor MX FLATMM example Refactor the MX FLATMM example to support more pipelines across different architectures. This work facilitates the NPI team roadmap.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
711374fcab
commit
b661eab573
@@ -20,7 +20,7 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
template <typename MXFlatmmArchTraits,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
@@ -49,6 +49,8 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
using FlatmmConfig = typename MXFlatmmArchTraits::Config;
|
||||
|
||||
ck_tile::ScaleFlatmmHostArgs<ScaleA, ScaleB> args = {a_dev_buf.GetDeviceBuffer(),
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
@@ -99,7 +101,7 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
constexpr auto has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_num_v = tail_num_.value;
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
return mx_flatmm_calc<FlatmmConfig,
|
||||
return mx_flatmm_calc<MXFlatmmArchTraits,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
@@ -157,22 +159,22 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert(
|
||||
"mx_prec", "fp4xfp4", "data type for activation and weight, support: fp4xfp4, fp8xfp8")
|
||||
.insert("mx_prec",
|
||||
"fp4xfp4",
|
||||
"data type for activation and weight, support: fp4xfp4, fp6xfp6, fp8xfp8, fp8xfp4 "
|
||||
"and fp4xfp8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
.insert("warp_tile", "0", "0: 16x16x128 on gfx950.");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <ck_tile::index_t N_Warp_Tile, typename dtype>
|
||||
template <ck_tile::index_t NLane, typename dtype>
|
||||
auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
auto src_lengths = src.get_lengths();
|
||||
@@ -181,8 +183,8 @@ auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
|
||||
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
|
||||
int KPack =
|
||||
std::is_same_v<dtype, ck_tile::pk_fp6x16_t> ? 32 : 16 * packed_size; // fp4/fp6:32 or fp8:16
|
||||
int NLane = N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
|
||||
int KLane = ck_tile::get_warp_size() / NLane;
|
||||
int K0 = K / (KLane * KPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
|
||||
@@ -211,68 +213,10 @@ auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
|
||||
return shuffled;
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, bool KLast, typename dtype>
|
||||
auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
auto src_lengths = src.get_lengths();
|
||||
const auto MN = KLast ? src_lengths[0] : src_lengths[1];
|
||||
const auto K = KLast ? src_lengths[1] : src_lengths[0];
|
||||
|
||||
size_t MNXdlPack = 2;
|
||||
size_t KXdlPack = 2;
|
||||
size_t XdlMNThread = FlatmmConfig::N_Warp_Tile; // 16
|
||||
size_t XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
const auto MN_Paded = ck_tile::integer_least_multiple(MN, XdlMNThread * MNXdlPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({MN_Paded * K}, {1}));
|
||||
|
||||
size_t K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(size_t n = 0; n < MN_Paded; ++n)
|
||||
{
|
||||
for(size_t k = 0; k < K; ++k)
|
||||
{
|
||||
auto n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
auto tempn = n % (XdlMNThread * MNXdlPack);
|
||||
auto n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
auto n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
auto k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
auto tempk = k % (XdlKThread * KXdlPack);
|
||||
auto k1 = tempk % XdlKThread; // i XdlKThread
|
||||
auto k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
auto outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
|
||||
if constexpr(KLast)
|
||||
shuffled(outputIndex) = n < MN ? src(n, k) : dtype{};
|
||||
else
|
||||
shuffled(outputIndex) = n < MN ? src(k, n) : dtype{};
|
||||
}
|
||||
}
|
||||
return shuffled;
|
||||
}
|
||||
|
||||
#include "run_mx_flatmm.inc"
|
||||
|
||||
int run_mx_flatmm_example(int argc, char* argv[])
|
||||
int run_mx_flatmm_example(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
@@ -281,6 +225,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
int persistent_opt = arg_parser.get_int("persistent");
|
||||
|
||||
std::cout << "Using default warptile of 16x16x128." << std::endl;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
|
||||
@@ -289,8 +235,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
MXfp4_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
MXFlatmm_GFX950_FP4FP4_Traits,
|
||||
false>(arg_parser, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only non-persistent kernels are supported currently!");
|
||||
}
|
||||
@@ -300,8 +246,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
return run_mx_flatmm_with_layouts<ck_tile::pk_fp6x16_t,
|
||||
ck_tile::pk_fp6x16_t,
|
||||
ck_tile::fp16_t,
|
||||
MXfp6_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
MXFlatmm_GFX950_FP6FP6_Traits,
|
||||
false>(arg_parser, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
@@ -311,8 +257,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp16_t,
|
||||
MXfp8_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
MXFlatmm_GFX950_FP8FP8_Traits,
|
||||
false>(arg_parser, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
@@ -322,8 +268,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
MXf8f4_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
MXFlatmm_GFX950_FP8FP4_Traits,
|
||||
false>(arg_parser, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
@@ -333,8 +279,8 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp16_t,
|
||||
MXf4f8_FlatmmConfig16,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
MXFlatmm_GFX950_FP4FP8_Traits,
|
||||
false>(arg_parser, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
@@ -359,7 +305,7 @@ int main(int argc, char* argv[])
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return run_mx_flatmm_example(argc, argv);
|
||||
return run_mx_flatmm_example(arg_parser);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
|
||||
@@ -11,167 +11,9 @@
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct MXfp4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 512;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
#include "mx_flatmm_arch_traits.hpp"
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MXfp6_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MXfp8_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MXf8f4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
struct MXf4f8_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
template <typename MXFlatmmArchTraits,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
|
||||
137
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp
Normal file
137
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp
Normal file
@@ -0,0 +1,137 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace core {
|
||||
namespace arch {
|
||||
|
||||
// Use the amdgcn_target_id enum from arch.hpp
|
||||
using TargetId = amdgcn_target_id;
|
||||
|
||||
} // namespace arch
|
||||
} // namespace core
|
||||
} // namespace ck_tile
|
||||
|
||||
// Base FlatmmConfig with 16x16 warp tile (for non-GFX1250)
|
||||
struct MXFlatmmConfigBase16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MXfp4_FlatmmConfig16 : public MXFlatmmConfigBase16
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Tile = 512;
|
||||
};
|
||||
|
||||
// Architecture traits for MX Flatmm - Primary template (gfx950 implementation)
|
||||
template <ck_tile::core::arch::TargetId Arch, typename FlatmmConfig>
|
||||
struct MXFlatmmArchTraits
|
||||
{
|
||||
static constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using Config = FlatmmConfig;
|
||||
|
||||
template <typename MXPipelineProblem>
|
||||
using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
|
||||
|
||||
static constexpr int GetNLane() { return Config::N_Warp_Tile; }
|
||||
|
||||
template <bool KLast, typename dtype>
|
||||
static auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
auto src_lengths = src.get_lengths();
|
||||
const auto MN = KLast ? src_lengths[0] : src_lengths[1];
|
||||
const auto K = KLast ? src_lengths[1] : src_lengths[0];
|
||||
|
||||
size_t MNXdlPack = 2;
|
||||
size_t KXdlPack = 2;
|
||||
size_t XdlMNThread = Config::N_Warp_Tile; // 16
|
||||
size_t XdlKThread = ck_tile::get_warp_size() / XdlMNThread;
|
||||
|
||||
const auto MN_Paded = ck_tile::integer_least_multiple(MN, XdlMNThread * MNXdlPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({MN_Paded * K}, {1}));
|
||||
|
||||
size_t K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(size_t n = 0; n < MN_Paded; ++n)
|
||||
{
|
||||
for(size_t k = 0; k < K; ++k)
|
||||
{
|
||||
auto n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
auto tempn = n % (XdlMNThread * MNXdlPack);
|
||||
auto n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
auto n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
auto k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
auto tempk = k % (XdlKThread * KXdlPack);
|
||||
auto k1 = tempk % XdlKThread; // i XdlKThread
|
||||
auto k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
auto outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread +
|
||||
n1 * MNXdlPack * KXdlPack + k2 * MNXdlPack + n2;
|
||||
|
||||
if constexpr(KLast)
|
||||
shuffled(outputIndex) = n < MN ? src(n, k) : dtype{};
|
||||
else
|
||||
shuffled(outputIndex) = n < MN ? src(k, n) : dtype{};
|
||||
}
|
||||
}
|
||||
return shuffled;
|
||||
}
|
||||
};
|
||||
|
||||
using MXFlatmm_GFX950_FP4FP4_Traits =
|
||||
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXfp4_FlatmmConfig16>;
|
||||
using MXFlatmm_GFX950_FP8FP8_Traits =
|
||||
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;
|
||||
using MXFlatmm_GFX950_FP6FP6_Traits =
|
||||
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;
|
||||
using MXFlatmm_GFX950_FP8FP4_Traits =
|
||||
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;
|
||||
using MXFlatmm_GFX950_FP4FP8_Traits =
|
||||
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;
|
||||
@@ -6,31 +6,33 @@ function(mx_flatmm_instance_generate FILE_LIST)
|
||||
set(A_LAYOUT ROW)
|
||||
set(B_LAYOUT COL)
|
||||
set(C_LAYOUT ROW)
|
||||
set(FLATMM_CONFIG_FP4xFP4 "MXfp4_FlatmmConfig16")
|
||||
set(FLATMM_CONFIG_FP8xFP8 "MXfp8_FlatmmConfig16")
|
||||
set(FLATMM_CONFIG_FP6xFP6 "MXfp6_FlatmmConfig16")
|
||||
set(FLATMM_CONFIG_FP8xFP4 "MXf8f4_FlatmmConfig16")
|
||||
set(FLATMM_CONFIG_FP4xFP8 "MXf4f8_FlatmmConfig16")
|
||||
|
||||
set(MXFLATMM_ARCH)
|
||||
|
||||
if (GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND MXFLATMM_ARCH MXFlatmm_GFX950_)
|
||||
endif()
|
||||
|
||||
# foreach(PERSISTENT false true)
|
||||
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
|
||||
foreach(PERSISTENT false)
|
||||
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8)
|
||||
set(FLATMM_CONFIG ${FLATMM_CONFIG_${DATA_TYPE}})
|
||||
string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
|
||||
list(GET DATA_TYPE_AB 0 A_DATA_TYPE)
|
||||
list(GET DATA_TYPE_AB 1 B_DATA_TYPE)
|
||||
|
||||
foreach(SPLIT_K false true)
|
||||
foreach(HAS_HOT_LOOP false true)
|
||||
foreach(TAIL_NUMBER ODD EVEN)
|
||||
set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${DATA_TYPE}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
|
||||
string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
|
||||
@ONLY)
|
||||
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
|
||||
foreach(ARCH ${MXFLATMM_ARCH})
|
||||
set(MXFLATMM_ARCH_TRAITS "${ARCH}${A_DATA_TYPE}${B_DATA_TYPE}_Traits")
|
||||
foreach(SPLIT_K false true)
|
||||
foreach(HAS_HOT_LOOP false true)
|
||||
foreach(TAIL_NUMBER ODD EVEN)
|
||||
set(KERNEL_FILE mxgemm/instance_${ARCH}${DATA_TYPE}_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
|
||||
string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
|
||||
@ONLY)
|
||||
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#include "mx_flatmm_instance.hpp"
|
||||
|
||||
// clang-format off
|
||||
#define FLATMM_CONFIG @FLATMM_CONFIG@
|
||||
#define MXFLATMM_ARCH_TRAITS @MXFLATMM_ARCH_TRAITS@
|
||||
#define A_DATA_TYPE @A_DATA_TYPE@
|
||||
#define B_DATA_TYPE @B_DATA_TYPE@
|
||||
#define C_DATA_TYPE @C_DATA_TYPE@
|
||||
@@ -37,7 +37,7 @@ inline constexpr int ScaleGranularityK = 32;
|
||||
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>;
|
||||
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>;
|
||||
|
||||
template float mx_flatmm_calc<FLATMM_CONFIG,
|
||||
template float mx_flatmm_calc<MXFLATMM_ARCH_TRAITS,
|
||||
A_DATA_TYPE,
|
||||
B_DATA_TYPE,
|
||||
/*DsDatatype*/ ck_tile::tuple<>,
|
||||
|
||||
@@ -16,7 +16,7 @@ template <typename Layout>
|
||||
using is_row_major_t = ck_tile::bool_constant<
|
||||
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
template <typename MXFlatmmArchTraits,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
@@ -36,6 +36,8 @@ template <typename FlatmmConfig,
|
||||
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using FlatmmConfig = typename MXFlatmmArchTraits::Config;
|
||||
|
||||
using FlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
@@ -63,7 +65,8 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
ck_tile::ignore = Splitk;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
// determined by scale shuffle pattern
|
||||
constexpr int BlockedXDLN_PerWarp = MXFlatmmArchTraits::BlockedXDLN_PerWarp;
|
||||
|
||||
using MXPipelineProblem = ck_tile::MXFlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -74,7 +77,8 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
HasHotLoop,
|
||||
TailNum>;
|
||||
|
||||
using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
|
||||
using MXFlatmmPipeline =
|
||||
typename MXFlatmmArchTraits::template MXFlatmmPipeline<MXPipelineProblem>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
|
||||
|
||||
@@ -4,21 +4,16 @@
|
||||
template <typename PrecActType,
|
||||
typename PrecWeightType,
|
||||
typename CDataType,
|
||||
typename FlatmmConfig,
|
||||
typename MXFlatmmArchTraits,
|
||||
bool UsePersistentKernel = false,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_mx_flatmm_with_layouts(int argc,
|
||||
char* argv[],
|
||||
int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using AccDataType = float;
|
||||
@@ -111,9 +106,9 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
}
|
||||
}
|
||||
|
||||
const auto b_shuffled_host = preShuffleWeight<FlatmmConfig::N_Warp_Tile>(b_origin_host);
|
||||
const auto scale_a_shuffled = preShuffleScale<FlatmmConfig, true>(scale_a);
|
||||
const auto scale_b_shuffled = preShuffleScale<FlatmmConfig, false>(scale_b);
|
||||
const auto b_shuffled_host = preShuffleWeight<MXFlatmmArchTraits::GetNLane()>(b_origin_host);
|
||||
const auto scale_a_shuffled = MXFlatmmArchTraits::template preShuffleScale<true>(scale_a);
|
||||
const auto scale_b_shuffled = MXFlatmmArchTraits::template preShuffleScale<false>(scale_b);
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes());
|
||||
@@ -135,7 +130,7 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
|
||||
invoke_mx_flatmm<FlatmmConfig,
|
||||
invoke_mx_flatmm<MXFlatmmArchTraits,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
|
||||
Reference in New Issue
Block a user