cherry pick related code

This commit is contained in:
lalala-sh
2025-09-22 08:05:55 +00:00
parent 427dca076b
commit d454d0e201
41 changed files with 10947 additions and 1112 deletions

View File

@@ -1,6 +1,28 @@
add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
add_executable(tile_example_mixed_prec_flatmm EXCLUDE_FROM_ALL mixed_prec/mixed_prec_flatmm.cpp)
add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-local-typedef -Wno-unused-variable -Wno-unused-parameter)
list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS -Wno-nrvo -Wno-unused-variable -Wno-unused-parameter -Wno-unused-local-typedef -Wno-float-equal)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS --save-temps -Wno-nrvo)
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS --save-temps)
target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS})
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS})

View File

@@ -7,7 +7,7 @@ This folder contains example for FLATMM using ck_tile tile-programming implement
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the flatmm calculation
make tile_example_flatmm_basic -j
```

View File

@@ -11,7 +11,102 @@
#include "ck_tile/host.hpp"
#include "flatmm_basic.hpp"
#include "run_flatmm_example.inc"
#include <type_traits>
template <typename T>
constexpr const char* DataTypeToString()
{
if constexpr(std::is_same_v<T, ck_tile::half_t>)
{
return "fp16";
}
else if constexpr(std::is_same_v<T, ck_tile::fp8_t>)
{
return "fp8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf8_t>)
{
return "bf8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
return "bf16";
}
else
{
return "unknown";
}
}
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// mfma_type, 0:32x32, 1:16x16
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int MaxVecSize = 16 / sizeof(T);
constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile;
constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane);
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / ItemsPerAccess,
ItemsPerAccess});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
}
template <typename FlatmmConfig, typename T>
auto shuffle_b_v1(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int MaxVecSize = 16 / sizeof(T);
constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile;
constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane);
constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp;
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
FlatmmConfig::N_Warp,
FlatmmConfig::N_Warp_Tile,
NRepeat,
k_ / ItemsPerAccess,
ItemsPerAccess});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5});
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename FlatmmConfig,
typename ADataType,
@@ -23,9 +118,12 @@ template <typename FlatmmConfig,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s)
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
@@ -80,14 +178,14 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
@@ -101,6 +199,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
DsLayout,
ELayout,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
@@ -110,7 +209,10 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups>>;
FlatmmConfig::NumWaveGroups,
false,
1,
FlatmmConfig::TiledMMAPermuteN>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
@@ -118,8 +220,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -167,16 +269,18 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
@@ -201,6 +305,111 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_c
return ave_time;
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleM,
typename ScaleN,
bool UsePersistentKernel = false,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
ScaleM scale_m,
ScaleN scale_n,
int n_warmup,
int n_repeat)
{
ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C,
scale_m,
scale_n};
float ave_time = flatmm_calc<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
ScaleM,
ScaleN,
UsePersistentKernel,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "256", "m dimension")
.insert("n", "256", "n dimension")
.insert("k", "128", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
.insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
#include "run_flatmm_example.inc"
template <template <typename PreType> typename FlatmmConfig>
int run_flatmm_example(int argc, char* argv[])
{
@@ -214,20 +423,10 @@ int run_flatmm_example(int argc, char* argv[])
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
int k = arg_parser.get_int("k");
int stride_b = arg_parser.get_int("stride_b");
if(b_layout == "C" && stride_b > k)
{
throw std::runtime_error(
"For ColumnMajor layout, StrideB must be smaller than or equal to K (" +
std::to_string(k) + ")");
}
int scale_opt = arg_parser.get_int("scale");
int persistent_opt = arg_parser.get_int("persistent");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
{
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
@@ -240,13 +439,53 @@ int run_flatmm_example(int argc, char* argv[])
}
else if(data_type == "fp8")
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
if(scale_opt == 0)
{
if(persistent_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::fp8_t,
FlatmmConfig<ck_tile::fp8_t>,
-1,
-1,
true>(argc, argv, Row{}, Col{}, Row{});
}
}
else
{
if(persistent_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::fp8_t,
FlatmmConfig<ck_tile::fp8_t>,
1,
1>(argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::fp8_t,
FlatmmConfig<ck_tile::fp8_t>,
1,
1,
true>(argc, argv, Row{}, Col{}, Row{});
}
}
}
else if(data_type == "bf8")
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
if(scale_opt == 0)
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
argc, argv, Row{}, Col{}, Row{});
}
}
else
{

View File

@@ -35,12 +35,13 @@ struct FlatmmConfig32
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 2;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
};
template <typename DataType>
@@ -72,26 +73,28 @@ struct FlatmmConfig16
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 2;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 4 == 0;
};
template <typename DataType>
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
{
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
};
static constexpr int kBlockPerCu = 1;
template <typename DataType>
struct FlatmmConfig16_Wmma : public FlatmmConfig16<DataType>
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 64;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr int N_Repeat =
N_Tile / FlatmmConfig16<DataType>::N_Warp_Tile / FlatmmConfig16<DataType>::N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 4 == 0;
};
template <typename ADataType>
@@ -172,42 +175,19 @@ struct is_8bit_type
{
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "256", "m dimension")
.insert("n", "256", "n dimension")
.insert("k", "128", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
#if !defined(CK_TILE_USE_WMMA)
.insert(
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
#endif
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "flatmm_basic.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
template <typename ADataType,
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename FlatmmConfig,
typename ALayout,
typename BLayout,
typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);

View File

@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
// GEMM config with 16x16 warp tile
struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr int kBlockPerCu = 1;
static constexpr int N_Repeat =
N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};

View File

@@ -0,0 +1,513 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "a16w4_moe_flatmm.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/reference/reference_moe_gemm.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// gemm1
// operand-A = [num_token, d_model]
// operand-B = [num_expert, hidden, d_model]
// operand-C = [num_token, topk, hidden]
// gemm2
// operand-A = [num_token, topk, hidden]
// operand-B = [num_expert, d_model, hidden]
// operand-C = [num_token, d_model]
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
typename MoeFlatmmHostArgs>
float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
ELayout,
FlatmmConfig::NumWaveGroups>;
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
FlatmmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
FlatmmConfig::TransposeC,
FlatmmConfig::UseStructuredSparsity,
false, // UsePersistentKernel_
FlatmmConfig::NumWaveGroups,
true>; // Preshuffle_
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
{
static_assert(
FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
"requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
}
using ComputeDataType = ADataType;
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
ComputeDataType,
AccDataType,
CodegenFlatmmShape,
Traits>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem =
std::conditional_t<MXFP4_Pipeline,
ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>,
ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>>;
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
ComputeDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false,
1,
FlatmmConfig::TiledMMAPermuteN,
BlockedXDLN_PerWarp>>;
using CodegenFlatmmPipeline = std::conditional_t<
MXFP4_Pipeline,
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>,
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>>;
using FusedAct =
std::conditional_t<MXFP4_Pipeline, ck_tile::moe::Swiglu, ck_tile::moe::MoeSilu>;
using Kernel = ck_tile::MoeFlatmmKernel<TilePartitioner,
CodegenFlatmmPipeline,
GemmEpilogue,
moe_kind,
FusedAct>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
? 2
: 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
? 2
: 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK
: args.NumTokens,
args.K,
args.stride_A,
is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{})));
const int outputN =
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N;
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), s.stream_id_));
else if(args.k_batch > 1)
hipGetErrorString(
hipMemsetAsync(args.e_ptr,
0,
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
template <class FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, class IterSrc, class IterDst>
void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N, int K)
{
int KPack = 16;
int NLane = FlatmmConfig::N_Warp_Tile;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
{
int up_stride = N / 2 / NLane;
for(long eid = 0; eid < experts_cnt; ++eid)
{
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
// interleave gate and up part with granularity is 16.
int n0_interleave = n >= N / 2 ? (n0 - up_stride) * 2 + 1 : // up part
n0 * 2; // gate part
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
long outputIndex = eid * N * K_pk + n0_interleave * KPack * NLane * KLane * K0 +
k0 * KPack * NLane * KLane + k1 * KPack * NLane +
n1 * KPack + k2;
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
}
}
}
}
else
{
for(long eid = 0; eid < experts_cnt; ++eid)
{
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
long outputIndex = eid * N * K_pk + n0 * KPack * NLane * KLane * K0 +
k0 * KPack * NLane * KLane + k1 * KPack * NLane +
n1 * KPack + k2;
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
}
}
}
}
}
template <typename FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, typename T>
auto shuffle_mxfp4_scale(const ck_tile::HostTensor<T>& scale, int experts_cnt)
{
assert(scale.get_lengths().size() == 2);
int n_ = scale.get_lengths()[1];
int k_ = scale.get_lengths()[0];
int k_per_expert = k_ / experts_cnt;
constexpr int K_Pack = 2; // fixed for mxfp4
constexpr int N_Pack = 2; // fixed for mxfp4
constexpr int GranularityK = 32; // fixed for mxfp4
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
{
ck_tile::HostTensor<T> shfl_scale({
experts_cnt,
k_per_expert / K_Pack / K_Lane,
K_Pack,
K_Lane,
N_Pack, // N_Pack = 2 is composed of Gate + Up.
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
FlatmmConfig::N_Warp_Tile,
});
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
return ck_tile::reference_permute(shfl_scale, {0, 5, 1, 3, 6, 2, 4});
}
else
{
ck_tile::HostTensor<T> shfl_scale({
experts_cnt,
k_per_expert / K_Pack / K_Lane,
K_Pack,
K_Lane,
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
N_Pack,
FlatmmConfig::N_Warp_Tile,
});
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
return ck_tile::reference_permute(shfl_scale, {0, 4, 1, 3, 6, 2, 5});
}
}
#include "run_a16w4_moe_flatmm_example.inc"
template <typename FlatmmConfig>
int run_a16w4_moe_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string mixed_prec = arg_parser.get_str("mixed_prec");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
{
const std::string gemm_kind = arg_parser.get_str("gemm_kind");
if(gemm_kind == "gemm1_gate_up")
{
if(mixed_prec == "fp16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::half_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "bf16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
}
}
else if(gemm_kind == "gemm2")
{
if(mixed_prec == "fp16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<ck_tile::half_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "bf16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm2!");
}
}
else
{
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
"[gemm1_gate_up | gemm2]");
}
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
return -1;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
try
{
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
}
// else if(warp_tile == 1)
// {
// return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
// }
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,87 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"
// GEMM config with 16x16 warp tile
struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr int kBlockPerCu = 1;
static constexpr int N_Repeat =
N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("experts", "8", "Num of experts - 8 by default")
.insert("NumTokens", "128", "M dimensions - 128 by default.")
.insert("TopK", "3", "Top K - 3 by default.")
.insert("N", "4096", "N dimensions - 4096 by default.")
.insert("K", "4096", "K dimensions - 4096 by default.")
.insert("stride_A", "", "Tensor A strides - it is empty by default.")
.insert("stride_B", "", "Tensor B strides - it is empty by default.")
.insert("stride_C", "", "Tensor C strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "C", "B tensor data layout - Col by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("gemm_kind",
"gemm1_gate_up",
"Gemm kind in FFN network [gemm1_gate_up | gemm2] - "
"gemm1_gate_up by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("mixed_prec",
"bf16xfp4",
"data type for activation and weight, support: bf16xfp4, fp16xfp4")
.insert("init", "0", "0:random, 1:constant(1)")
.insert("warp_tile",
"0",
"0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)")
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}

View File

@@ -0,0 +1,484 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <type_traits>
#include "ck_tile/host.hpp"
#include "mixed_prec_flatmm.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
ELayout,
FlatmmConfig::NumWaveGroups>;
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
FlatmmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
FlatmmConfig::TransposeC,
FlatmmConfig::UseStructuredSparsity,
persistent,
FlatmmConfig::NumWaveGroups,
true>;
using ComputeDataType = ADataType;
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
ComputeDataType,
AccDataType,
CodegenFlatmmShape,
Traits>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
ComputeDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false, // FixedVectorSize
1, // VectorSizeC
FlatmmConfig::TiledMMAPermuteN,
BlockedXDLN_PerWarp>>;
using Kernel =
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleN,
bool UsePersistentKernel = false,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_mixed_prec_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
ScaleN dequant_scale_n,
int n_warmup,
int n_repeat)
{
// Activation has no scale
using ActScaleType = ck_tile::FlatmmScalePointer<-1>;
ck_tile::ScaleFlatmmHostArgs<ActScaleType, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C,
{},
dequant_scale_n};
float ave_time = mixed_prec_flatmm_calc<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
ActScaleType,
ScaleN,
UsePersistentKernel,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K / PackedSize +
sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run A16W4_Flatmm kernel "
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "256", "m dimension")
.insert("n", "256", "n dimension")
.insert("k", "512", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Row by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "1", "0. No validation, 1. Validation on GPU")
.insert("mixed_prec",
"bf16xfp4",
"data type for activation and weight, support: bf16xfp4, fp16xfp4")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:constant(1)")
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <class FlatmmConfig, class IterSrc, class IterDst>
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
{
int KPack = 16;
int NLane = FlatmmConfig::N_Warp_Tile;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
}
}
}
template <class FlatmmConfig, class T>
auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
{
assert(scale.get_lengths().size() == 2);
int n_ = scale.get_lengths()[1];
int k_ = scale.get_lengths()[0];
constexpr int K_Pack = 2; // fixed for mxfp4
constexpr int N_Pack = 2; // fixed for mxfp4
constexpr int GranularityK = 32; // fixed for mxfp4
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
ck_tile::HostTensor<T> shfl_scale({
k_ / K_Pack / K_Lane,
K_Pack,
K_Lane,
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
N_Pack,
FlatmmConfig::N_Warp_Tile,
});
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
return ck_tile::reference_permute(shfl_scale, {3, 0, 2, 5, 1, 4});
}
#include "run_mixed_prec_flatmm.inc"
template <typename FlatmmConfig>
int run_mixed_prec_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
std::string mixed_prec = arg_parser.get_str("mixed_prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
int persistent_opt = arg_parser.get_int("persistent");
if(a_layout == "R" && b_layout == "C")
{
if(mixed_prec == "bf16xfp4")
{
if(persistent_opt == 0)
{
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
false>(argc, argv, Row{}, Col{}, Row{});
}
else
{
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
true>(argc, argv, Row{}, Col{}, Row{});
}
}
else if(mixed_prec == "fp16xfp4")
{
if(persistent_opt == 0)
{
run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
false>(argc, argv, Row{}, Col{}, Row{});
}
else
{
run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
true>(argc, argv, Row{}, Col{}, Row{});
}
}
else
{
throw std::runtime_error("Unsupported data_type!");
}
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
return -1;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
try
{
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
}
else if(warp_tile == 1)
{
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
}
else
{
throw std::runtime_error("Unsupported warp_tile!");
}
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,15 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "a16w4_flatmm.hpp"

View File

@@ -0,0 +1,356 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ck_tile::MoeFlatmmKind kind,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
typename MoeHostArgs>
float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
{
float ave_time = a16w4_moe_gemm<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
kind,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::string op_name{"Moe Gemm"};
constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
sizeof(BDataType) * args.N * args.K / PackedSize +
sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
return ave_time;
}
template <typename PrecActType,
typename PrecWeightType,
typename FlatmmConfig,
ck_tile::MoeFlatmmKind kind,
typename ALayout,
typename BLayout,
typename CLayout>
int run_a16w4_moe_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
using ADataType = PrecActType;
using BDataType = PrecWeightType;
using CDataType = PrecActType;
using AccDataType = float;
using ScaleType = ck_tile::e8m0_t;
constexpr int ScaleGranularityN = 1;
constexpr int ScaleGranularityK = 32;
const ck_tile::index_t N = arg_parser.get_int("N");
const ck_tile::index_t K = arg_parser.get_int("K");
ck_tile::index_t stride_A = arg_parser.get_int("stride_A");
ck_tile::index_t stride_B = arg_parser.get_int("stride_B");
ck_tile::index_t stride_C = arg_parser.get_int("stride_C");
ck_tile::index_t init_method = arg_parser.get_int("init");
const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
const ck_tile::index_t topk = arg_parser.get_int("TopK");
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
const ck_tile::index_t experts = arg_parser.get_int("experts");
// TODO: replace the magic declaration
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
ck_tile::index_t valid_tile_num = sorted_tile_num;
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;
const ck_tile::index_t M = sorted_tile_num * MPerBlock;
const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;
stride_A = ck_tile::get_default_stride(
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));
auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
is_row_major(b_layout)
? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
: ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
{K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
ck_tile::FillUniformDistribution<ScaleType>{0.f, 1.f}(scale_b);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f}(b_k_n_tensor);
ck_tile::FillUniformDistribution<ScaleType>{1.0f, 1.0f}(scale_b);
}
ck_tile::HostTensor<BDataType> b_shuffle_host(
ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
shuffle_mxfp4_weight<FlatmmConfig, kind>(
b_k_n_tensor.begin(), b_shuffle_host.begin(), experts, N, K);
ck_tile::HostTensor<ScaleType> scale_b_shuffle =
shuffle_mxfp4_scale<FlatmmConfig, kind>(scale_b, experts);
ck_tile::DeviceMem scale_b_shuffle_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());
std::cout << "moe_flatmm:"
<< "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
<< "\n topk: " << topk << "\n sorted_tile_num: " << sorted_tile_num
<< "\n problem_n: " << N << "\n problem_k: " << K
<< "\n a_m_k: " << a_m_k_tensor.mDesc << "\n b_k_n: " << b_k_n_tensor.mDesc
<< "\n b_shuffle: " << b_shuffle_host.mDesc << "\n c_m_n: " << c_m_n_tensor.mDesc
<< std::endl;
ck_tile::HostTensor<ck_tile::index_t> expert_ids(
ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
ck_tile::HostTensor<AccDataType> expert_weight(
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
ck_tile::HostTensor<ck_tile::index_t> max_token_id(
ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));
ck_tile::HostTensor<AccDataType> expert_bias(ck_tile::HostTensorDescriptor({experts * N}, {1}));
if(init_method == 0)
{
// for verification only, no need to satify weight normalization
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);
ck_tile::FillUniformDistribution<AccDataType>{-1.0f, 1.0f}(expert_bias);
}
else
{
ck_tile::FillUniformDistribution<AccDataType>{1.0f, 1.0f}(expert_weight);
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 0.0f}(expert_bias);
}
max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
// int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
}
int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
// int token_per_tile = num_tokens * topk / valid_tile_num;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile && tokenid < num_tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = num_tokens;
}
}
ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};
a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
c_m_n_dev_buf.SetZero();
c_m_n_tensor.SetZero();
ck_tile::DeviceMem sorted_token_ids_dev{sorted_token_ids.get_element_space_size_in_bytes()};
ck_tile::DeviceMem expert_ids_dev{expert_ids.get_element_space_size_in_bytes()};
ck_tile::DeviceMem max_token_id_dev{max_token_id.get_element_space_size_in_bytes()};
ck_tile::DeviceMem expert_weight_dev{expert_weight.get_element_space_size_in_bytes()};
ck_tile::DeviceMem expert_bias_dev{expert_bias.get_element_space_size_in_bytes()};
sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
expert_ids_dev.ToDevice(expert_ids.data());
max_token_id_dev.ToDevice(max_token_id.data());
expert_weight_dev.ToDevice(expert_weight.data());
expert_bias_dev.ToDevice(expert_bias.data());
scale_b_shuffle_dev_buf.ToDevice(scale_b_shuffle.data());
const ck_tile::index_t* p_sorted_token_ids_dev =
static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
const ck_tile::index_t* p_expert_ids_dev =
static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
const ck_tile::index_t* p_max_token_id_dev =
static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
const AccDataType* p_sorted_expert_weight_dev =
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
auto scale_b_shuffle_dev_ptr =
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
static_cast<float*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<1>{
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()), experts * N};
using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
ck_tile::FlatmmScalePointer<-1>,
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>,
ck_tile::FlatmmScalePointer<1>>;
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
p_sorted_expert_weight_dev,
p_expert_ids_dev,
p_max_token_id_dev,
a_m_k_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
num_tokens,
experts,
topk,
1, // k_batch
M,
N,
K,
stride_A,
stride_B,
stride_C,
nullptr,
scale_b_shuffle_dev_ptr,
exp_bias_dev_ptr};
invoke_a16w4_moe_gemm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
kind>(warmup, repeat, gemm_desc);
c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());
bool pass{true};
if(arg_parser.get_int("validate"))
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
outputN,
stride_C,
is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::HostTensor<AccDataType> scale_A(
ck_tile::HostTensorDescriptor({1, K / ScaleGranularityK}, {1, 1}));
// scaleA = 1 has no effect on the result
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
scale_A_dev_buf.ToDevice(scale_A.data());
// convert scale_b from e8m0 to float
ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
{K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));
std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
scale_b_float_dev_buf.ToDevice(scale_b_float.data());
std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
c_m_n_ref_buf->SetZero();
ck_tile::reference_moe_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
static_cast<int>(kind),
ck_tile::moe::Swiglu>(
p_sorted_token_ids_dev,
p_expert_ids_dev,
p_max_token_id_dev,
static_cast<const ADataType*>(a_m_k_dev_buf.GetDeviceBuffer()),
static_cast<const BDataType*>(b_origin_dev_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
p_sorted_expert_weight_dev,
num_tokens,
MPerBlock,
topk,
M,
N,
K,
stride_A,
stride_B,
stride_C,
M,
1,
ScaleGranularityK,
static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()),
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()));
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
pass = ck_tile::check_err(
c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}

View File

@@ -0,0 +1,180 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
template <typename PrecActType,
typename PrecWeightType,
typename FlatmmConfig,
bool UsePersistentKernel = false,
typename ALayout,
typename BLayout,
typename CLayout>
int run_mixed_prec_flatmm_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using ADataType = PrecActType;
using BDataType = PrecWeightType;
using CDataType = PrecActType;
using AccDataType = float;
using ScaleType = ck_tile::e8m0_t;
constexpr int DequantGranularityN = 1;
constexpr int DequantGranularityK = 32;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
ck_tile::index_t init_method = arg_parser.get_int("init");
ck_tile::index_t n_warmup = arg_parser.get_int("warmup");
ck_tile::index_t n_repeat = arg_parser.get_int("repeat");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_host(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_origin_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_rslt_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
{K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
}
else if(init_method == 1)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
}
ck_tile::HostTensor<BDataType> b_shuffle_host(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffle_host.begin(), N, K);
ck_tile::HostTensor<ScaleType> scale_b_shuffle = preShuffleScale<FlatmmConfig>(scale_b);
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());
a_dev_buf.ToDevice(a_host.data());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
c_rslt_host.SetZero();
scale_b_dev_buf.ToDevice(scale_b_shuffle.data());
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
invoke_mixed_prec_flatmm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
decltype(scale_b_dev_ptr),
UsePersistentKernel>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
scale_b_dev_ptr,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::DeviceMem b_origin_dev_buf(b_origin_host.get_element_space_size_in_bytes());
b_origin_dev_buf.ToDevice(b_origin_host.data());
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::DeviceMem c_gpu_ref_dev_buf(c_gpu_ref_host.get_element_space_size_in_bytes());
ck_tile::HostTensor<AccDataType> scale_A(
ck_tile::HostTensorDescriptor({1, K / DequantGranularityK}, {1, 1}));
// scaleA = 1 has no effect on the result
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
scale_A_dev_buf.ToDevice(scale_A.data());
// convert scale_b from e8m0 to float
ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
{K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
scale_b_float_dev_buf.ToDevice(scale_b_float.data());
c_gpu_ref_dev_buf.SetZero();
ck_tile::reference_blockwise_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
static_cast<ADataType*>(a_dev_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_origin_dev_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_gpu_ref_dev_buf.GetDeviceBuffer()),
M,
N,
K,
stride_A,
stride_B,
stride_C,
M,
DequantGranularityN,
DequantGranularityK,
static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()));
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
pass = ck_tile::check_err(
c_rslt_host, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}

View File

@@ -0,0 +1,473 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "moe_flatmm.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/reference/reference_moe_gemm.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int MaxVecSize = 16 / sizeof(T);
constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile;
constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane);
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / ItemsPerAccess,
ItemsPerAccess});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
// gemm1
// operand-A = [num_token, d_model]
// operand-B = [num_expert, hidden, d_model]
// operand-C = [num_token, topk, hidden]
// gemm2
// operand-A = [num_token, topk, hidden]
// operand-B = [num_expert, d_model, hidden]
// operand-C = [num_token, d_model]
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
typename ScaleM,
typename ScaleN>
float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
ELayout,
FlatmmConfig::NumWaveGroups>;
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
FlatmmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
FlatmmConfig::TransposeC,
FlatmmConfig::UseStructuredSparsity,
false, // UsePersistentKernel_
FlatmmConfig::NumWaveGroups,
true>; // Preshuffle_
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
{
static_assert(
FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
"requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
}
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
constexpr int BlockedXDLN_PerWarp = moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up
? 2
: 1; // determined by scale shuffle pattern
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false,
1,
FlatmmConfig::TiledMMAPermuteN,
BlockedXDLN_PerWarp>>;
using CodegenFlatmmPipeline =
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using Kernel = ck_tile::
MoeFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue, moe_kind>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK
: args.NumTokens,
args.K,
args.stride_A,
is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{})));
const int outputN =
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N;
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), s.stream_id_));
else if(args.k_batch > 1)
hipGetErrorString(
hipMemsetAsync(args.e_ptr,
0,
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
#include "run_moe_flatmm_example.inc"
template <template <typename PreType> typename FlatmmConfig>
int run_moe_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string prec_type = arg_parser.get_str("prec");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
{
const std::string gemm_kind = arg_parser.get_str("gemm_kind");
if(gemm_kind == "gemm1_gate_up")
{
if(prec_type == "fp8")
{
return run_moe_gemm_example_with_layouts<
ck_tile::fp8_t,
FlatmmConfig<ck_tile::fp8_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf8")
{
return run_moe_gemm_example_with_layouts<
ck_tile::bf8_t,
FlatmmConfig<ck_tile::bf8_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf16")
{
return run_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
FlatmmConfig<ck_tile::bfloat16_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "fp16")
{
return run_moe_gemm_example_with_layouts<
ck_tile::half_t,
FlatmmConfig<ck_tile::half_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
}
}
else if(gemm_kind == "gemm1_gate_only")
{
if(prec_type == "fp8")
{
return run_moe_gemm_example_with_layouts<
ck_tile::fp8_t,
FlatmmConfig<ck_tile::fp8_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf8")
{
return run_moe_gemm_example_with_layouts<
ck_tile::bf8_t,
FlatmmConfig<ck_tile::bf8_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf16")
{
return run_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
FlatmmConfig<ck_tile::bfloat16_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "fp16")
{
return run_moe_gemm_example_with_layouts<
ck_tile::half_t,
FlatmmConfig<ck_tile::half_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
}
}
else if(gemm_kind == "gemm2")
{
if(prec_type == "fp8")
{
return run_moe_gemm_example_with_layouts<ck_tile::fp8_t,
FlatmmConfig<ck_tile::fp8_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf8")
{
return run_moe_gemm_example_with_layouts<ck_tile::bf8_t,
FlatmmConfig<ck_tile::bf8_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf16")
{
return run_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
FlatmmConfig<ck_tile::bfloat16_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "fp16")
{
return run_moe_gemm_example_with_layouts<ck_tile::half_t,
FlatmmConfig<ck_tile::half_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
}
}
else
{
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
"[gemm1_gate_only | gemm1_gate_up | gemm2]");
}
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
return -1;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
try
{
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return !run_moe_flatmm_example<FlatmmConfig16>(argc, argv);
}
else if(warp_tile == 1)
{
return !run_moe_flatmm_example<FlatmmConfig32>(argc, argv);
}
else if(warp_tile == 2)
{
return !run_moe_flatmm_example<FlatmmConfig16_950>(argc, argv);
}
else
{
return !run_moe_flatmm_example<FlatmmConfig32_950>(argc, argv);
}
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,202 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"
template <typename DataType>
struct FlatmmConfig32
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 32;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
};
template <typename DataType>
struct FlatmmConfig32_950 : public FlatmmConfig32<DataType>
{
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64;
};
// GEMM config with 16x16 warp tile
template <typename DataType>
struct FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 64;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
template <typename DataType>
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
{
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
static constexpr int kBlockPerCu = 1;
static constexpr int N_Repeat =
N_Tile / FlatmmConfig16<DataType>::N_Warp_Tile / FlatmmConfig16<DataType>::N_Warp;
static constexpr bool TiledMMAPermuteN = false; // N_Repeat % 2 == 0;
};
template <typename ADataType>
struct GemmBasicTypeConfig;
template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};
template <>
struct GemmBasicTypeConfig<ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmBasicTypeConfig<ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};
template <>
struct GemmBasicTypeConfig<ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <typename T>
struct is_8bit_type
: std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>
{
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("experts", "8", "Num of experts - 8 by default")
.insert("NumTokens", "128", "M dimensions - 128 by default.")
.insert("TopK", "3", "Top K - 3 by default.")
.insert("N", "4096", "N dimensions - 4096 by default.")
.insert("K", "4096", "K dimensions - 4096 by default.")
.insert("stride_A", "", "Tensor A strides - it is empty by default.")
.insert("stride_B", "", "Tensor B strides - it is empty by default.")
.insert("stride_C", "", "Tensor C strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "C", "B tensor data layout - Col by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("gemm_kind",
"gemm1_gate_only",
"Gemm kind in FFN network [gemm1_gate_only | gemm1_gate_up | gemm2] - "
"gemm1_gate_only by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert(
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}

View File

@@ -1,175 +1,12 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include "ck_tile/utility/json_dump.hpp"
template <typename T>
constexpr const char* DataTypeToString()
{
if constexpr(std::is_same_v<T, ck_tile::half_t>)
{
return "fp16";
}
else if constexpr(std::is_same_v<T, ck_tile::fp8_t>)
{
return "fp8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf8_t>)
{
return "bf8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
return "bf16";
}
else
{
return "unknown";
}
}
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// mfma_type, 0:32x32, 1:16x16
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
if(ck_tile::is_gfx12_supported())
{
constexpr int divisor = 2;
constexpr int kABK1PerLane = 8;
constexpr int kABK0PerLane = FlatmmConfig::K_Warp_Tile / divisor / kABK1PerLane;
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
kABK0PerLane,
divisor,
kABK1PerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
}
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
FlatmmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
float ave_time = flatmm_calc<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
return ave_time;
}
template <typename PrecType,
typename FlatmmConfig,
int ScaleGranularityM = -1,
int ScaleGranularityN = -1,
bool UsePersistentKernel = false,
typename ALayout,
typename BLayout,
typename CLayout>
@@ -213,31 +50,32 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::HostTensor<CDataType> c_rslt_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));
// TODO: add different init types
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
// ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
// ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
}
else if(init_method == 3)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
}
else if(init_method == 4)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
}
else
{
@@ -248,52 +86,69 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_channel_scale_dev_buf(
per_channel_scale.get_element_space_size_in_bytes());
a_dev_buf.ToDevice(a_host.data());
c_rslt_host.SetZero();
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
// do pre-shuffle
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig>(b_origin_host);
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
if constexpr(FlatmmConfig::TiledMMAPermuteN)
{
return shuffle_b_v1<FlatmmConfig>(b_origin_host);
}
else
{
return shuffle_b<FlatmmConfig>(b_origin_host);
}
}();
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
float ave_time = invoke_flatmm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
invoke_flatmm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
decltype(per_token_scale_dev_ptr),
decltype(per_channel_scale_dev_ptr),
UsePersistentKernel>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
per_token_scale_dev_ptr,
per_channel_scale_dev_ptr,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
if(ScaleGranularityM != -1 || ScaleGranularityN != -1)
throw std::runtime_error("ScaleAB is not supported for CPU verification!\n");
ck_tile::HostTensor<CDataType> c_ref_host(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_ref_host.SetZero();
@@ -341,13 +196,41 @@ int run_flatmm_example_with_layouts(int argc,
N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
{
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
}
else
{
ck_tile::reference_blockwise_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
d_A,
d_B,
d_C,
M,
N,
K,
stride_A,
stride_B,
stride_C,
ScaleGranularityM,
ScaleGranularityN,
K,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
}
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(),
d_C,
@@ -375,22 +258,5 @@ int run_flatmm_example_with_layouts(int argc,
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
if(arg_parser.get_int("json") == 1)
{
dump_flatmm_json_results(arg_parser.get_str("jsonfile"),
DataTypeToString<ADataType>(),
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
pass,
ave_time,
tflops,
gb_per_sec);
}
return pass;
}

View File

@@ -0,0 +1,323 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ck_tile::MoeFlatmmKind kind,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
typename MoeHostArgs>
float invoke_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
{
float ave_time = moe_gemm<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
kind,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::string op_name{"Moe Gemm"};
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
sizeof(BDataType) * args.N * args.K +
sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
return ave_time;
}
template <typename PrecType,
typename FlatmmConfig,
ck_tile::MoeFlatmmKind kind,
typename ALayout,
typename BLayout,
typename CLayout>
int run_moe_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 1;
const ck_tile::index_t N = arg_parser.get_int("N");
const ck_tile::index_t K = arg_parser.get_int("K");
ck_tile::index_t stride_A = arg_parser.get_int("stride_A");
ck_tile::index_t stride_B = arg_parser.get_int("stride_B");
ck_tile::index_t stride_C = arg_parser.get_int("stride_C");
const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
const ck_tile::index_t topk = arg_parser.get_int("TopK");
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
const ck_tile::index_t experts = arg_parser.get_int("experts");
// TODO: replace the magic declaration
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
ck_tile::index_t valid_tile_num = sorted_tile_num;
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;
const ck_tile::index_t M = sorted_tile_num * MPerBlock;
const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;
stride_A = ck_tile::get_default_stride(
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));
auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
is_row_major(b_layout)
? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
: ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
auto b_shuffle_host = shuffle_b<FlatmmConfig>(b_k_n_tensor);
std::cout << "moe_flatmm:"
<< "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
<< "\n topk: " << topk << "\n sorted_tile_num: " << sorted_tile_num
<< "\n a_m_k: " << a_m_k_tensor.mDesc << "\n b_k_n: " << b_k_n_tensor.mDesc
<< "\n b_shuffle: " << b_shuffle_host.mDesc << "\n c_m_n: " << c_m_n_tensor.mDesc
<< std::endl;
ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};
a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
c_m_n_dev_buf.SetZero();
c_m_n_tensor.SetZero();
const void* p_a = a_m_k_dev_buf.GetDeviceBuffer();
const void* p_b_origin = b_origin_dev_buf.GetDeviceBuffer();
const void* p_b_shuffle = b_shuffle_dev_buf.GetDeviceBuffer();
void* p_c = c_m_n_dev_buf.GetDeviceBuffer();
// TODO: malloc and init sorted tokens and max tokens buffer
ck_tile::HostTensor<ck_tile::index_t> expert_ids(
ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
ck_tile::HostTensor<AccDataType> expert_weight(
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
ck_tile::HostTensor<ck_tile::index_t> max_token_id(
ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));
ck_tile::HostTensor<AccDataType> per_token_scale(
ck_tile::HostTensorDescriptor({IsInputGemm ? num_tokens : M}, {1}));
ck_tile::HostTensor<AccDataType> per_channel_scale(
ck_tile::HostTensorDescriptor({N * experts}, {1}));
ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_channel_scale);
// for verification only, no need to satify weight normalization
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);
ck_tile::DeviceMem sorted_token_ids_dev{sorted_token_ids.get_element_space_size_in_bytes()};
ck_tile::DeviceMem expert_ids_dev{expert_ids.get_element_space_size_in_bytes()};
ck_tile::DeviceMem max_token_id_dev{max_token_id.get_element_space_size_in_bytes()};
ck_tile::DeviceMem expert_weight_dev{expert_weight.get_element_space_size_in_bytes()};
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_channel_scale_dev_buf(
per_channel_scale.get_element_space_size_in_bytes());
max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
// int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
}
int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
// int token_per_tile = num_tokens * topk / valid_tile_num;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile && tokenid < num_tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = num_tokens;
}
}
sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
expert_ids_dev.ToDevice(expert_ids.data());
max_token_id_dev.ToDevice(max_token_id.data());
expert_weight_dev.ToDevice(expert_weight.data());
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
const ck_tile::index_t* p_sorted_token_ids_dev =
static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
const ck_tile::index_t* p_expert_ids_dev =
static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
const ck_tile::index_t* p_max_token_id_dev =
static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
const AccDataType* p_sorted_expert_weight_dev =
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
using MoeFlatmmArgs =
ck_tile::MoeFlatmmHostArgs<ck_tile::FlatmmScalePointer<ScaleGranularityM>,
ck_tile::FlatmmScalePointer<ScaleGranularityN>>;
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
p_sorted_expert_weight_dev,
p_expert_ids_dev,
p_max_token_id_dev,
p_a,
p_b_shuffle,
p_c,
num_tokens,
experts,
topk,
1, // k_batch
M,
N,
K,
stride_A,
stride_B,
stride_C,
per_token_scale_dev_ptr,
per_channel_scale_dev_ptr};
invoke_moe_gemm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
kind>(warmup, repeat, gemm_desc);
c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());
bool pass{true};
if(arg_parser.get_int("validate"))
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
outputN,
stride_C,
is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
c_m_n_ref_buf->SetZero();
ck_tile::reference_moe_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
static_cast<int>(kind),
ck_tile::moe::MoeSilu>(
p_sorted_token_ids_dev,
p_expert_ids_dev,
p_max_token_id_dev,
static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b_origin),
static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
p_sorted_expert_weight_dev,
num_tokens,
MPerBlock,
topk,
M,
N,
K,
stride_A,
stride_B,
stride_C,
1,
1,
K,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1 /*kbatch*/, max_accumulated_value);
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
pass = ck_tile::check_err(
c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}

View File

@@ -1242,6 +1242,15 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add bf16
// TODO: Replace with bf16x2_t, but llvm builins only accept cktile_llvm_bf16x2_t now.
CK_TILE_DEVICE_EXTERN cktile_llvm_bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
cktile_llvm_bf16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");
// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
@@ -1476,8 +1485,11 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, e8m0_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;
@@ -2201,6 +2213,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
@@ -2294,6 +2307,40 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
});
}
}
else if constexpr(std::is_same<T, bf16_t>::value)
{
if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
bit_cast<cktile_llvm_bf16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
static_for<0, 2, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<cktile_llvm_bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
else if constexpr(N == 8)
{
static_for<0, 4, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<cktile_llvm_bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
}
else if constexpr(std::is_same<T, int32_t>::value)
{
if constexpr(N == 1)
@@ -2809,8 +2856,10 @@ __device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::bf16_t>)
{
typedef __attribute__((__vector_size__(4 * sizeof(__bf16)))) __bf16 llvm_bf16x4_t;
auto lds_ptr = reinterpret_cast<__LDS_ADDR llvm_bf16x4_t*>(in_ptr_);
// To use llvm builtins.
__attribute__((address_space(3))) cktile_llvm_bf16x4_t* lds_ptr =
reinterpret_cast<__attribute__((address_space(3))) cktile_llvm_bf16x4_t*>(
reinterpret_cast<uintptr_t>(in_ptr));
return bit_cast<thread_buffer<T, N>>(__builtin_amdgcn_ds_read_tr16_b64_v4bf16(lds_ptr));
}
else if constexpr(std::is_same_v<remove_cvref_t<T>, ck_tile::fp8_t> ||

View File

@@ -1110,6 +1110,15 @@ CK_TILE_DEVICE_EXTERN fp16x2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");
// buffer atomic-add bf16
// TODO: Replace with bf16x2_t, but llvm builins only accept cktile_llvm_bf16x2_t now.
CK_TILE_DEVICE_EXTERN cktile_llvm_bf16x2_t llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
cktile_llvm_bf16x2_t vdata,
int32x4_t rsrc,
index_t voffset,
index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2bf16");
// buffer atomic-add i32
CK_TILE_DEVICE_EXTERN int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
int32_t vdata,
@@ -1344,8 +1353,11 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, e8m0_bexp_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;
@@ -1984,6 +1996,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
{
static_assert((std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)),
"wrong! not implemented");
@@ -2077,6 +2090,40 @@ CK_TILE_DEVICE void amd_buffer_atomic_add_impl(const thread_buffer<T, N>& src_th
});
}
}
else if constexpr(std::is_same<T, bf16_t>::value)
{
if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
bit_cast<cktile_llvm_bf16x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
0);
}
else if constexpr(N == 4)
{
static_for<0, 2, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<cktile_llvm_bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
else if constexpr(N == 8)
{
static_for<0, 4, 1>{}([&](auto i) {
llvm_amdgcn_raw_buffer_atomic_add_bf16x2(
src_thread_data.template get_as<cktile_llvm_bf16x2_t>()[i],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + i * sizeof(bf16x2_t),
0);
});
}
}
else if constexpr(std::is_same<T, int32_t>::value)
{
if constexpr(N == 1)

View File

@@ -6,10 +6,6 @@
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#define HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2f16) && \
__has_builtin(__builtin_amdgcn_global_atomic_fadd_v2bf16)
namespace ck_tile {
template <typename T, typename ComputeType>
@@ -18,6 +14,14 @@ CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
}
CK_TILE_HOST_DEVICE fp16x2_t add_fp16x2_t(const fp16x2_t& a, const fp16x2_t& b)
{
fp16x2_t rtn;
rtn[0] = add<fp16_t, float>(a[0], b[0]);
rtn[1] = add<fp16_t, float>(a[1], b[1]);
return rtn;
}
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{
bf16x2_t rtn;
@@ -36,14 +40,6 @@ CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
return rtn;
}
CK_TILE_HOST_DEVICE fp16x2_t add_f16x2_t(const fp16x2_t& a, const fp16x2_t& b)
{
fp16x2_t rtn;
rtn[0] = add<fp16_t, float>(a[0], b[0]);
rtn[1] = add<fp16_t, float>(a[1], b[1]);
return rtn;
}
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
fp8x4_t rtn;
@@ -99,6 +95,37 @@ CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
template <typename X>
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
template <>
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, const fp16x2_t& x)
{
union U32FP162_ADDR
{
uint32_t* u32_a;
fp16x2_t* fp162_a;
};
union U32FP162
{
uint32_t u32;
fp16x2_t fp162;
};
U32FP162_ADDR dword_addr;
U32FP162 cur_v;
U32FP162 new_;
uint32_t old_v, new_v;
dword_addr.fp162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.fp162 = add_fp16x2_t(cur_v.fp162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
}
template <>
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
{
@@ -316,44 +343,6 @@ CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
} while(cur_v.u64 != old_v);
}
//
// Atomic add for fp16x2_t
//
template <>
CK_TILE_DEVICE void atomic_add<fp16x2_t>(fp16x2_t* p_dst, fp16x2_t const& x)
{
#if HAS_GLOBAL_ATOMIC_PK_ADD_BUILTIN
__builtin_amdgcn_global_atomic_fadd_v2f16(c_style_pointer_cast<fp16x2_t*>(p_dst), x);
#else
union U32F162_ADDR
{
uint32_t* u32_a;
fp16x2_t* f162_a;
};
union U32F162
{
uint32_t u32;
fp16x2_t f162;
};
U32F162_ADDR dword_addr;
U32F162 cur_v;
U32F162 new_;
uint32_t old_v, new_v;
dword_addr.f162_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.f162 = add_f16x2_t(cur_v.f162, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
#endif
}
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
@@ -361,7 +350,6 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
(std::is_same<T, uint32_t>::value && (N == 1)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
(std::is_same<T, fp16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
@@ -466,13 +454,6 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
}
}
else if constexpr(std::is_same<T, fp16_t>::value)
{
static_for<0, N / 2, 1>{}([&](auto i) {
atomic_add(c_style_pointer_cast<fp16x2_t*>(p_dst) + i,
x.template get_as<fp16x2_t>()[i]);
});
}
}
template <typename T, index_t N>

View File

@@ -9,10 +9,13 @@
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/print.hpp"
#include "ck_tile/core/utility/functional.hpp"
namespace ck_tile {
template <index_t, index_t, index_t>
struct static_for;
template <index_t...>
struct sequence;
@@ -196,23 +199,14 @@ struct sequence
{
return sequence<f(Is)...>{};
}
};
template <index_t... Is>
CK_TILE_HOST_DEVICE static void print(const sequence<Is...>&)
{
printf("sequence<");
if constexpr(sizeof...(Is) > 0)
CK_TILE_HOST_DEVICE static void print()
{
bool first = true;
(([&first](index_t value) {
printf("%s%d", first ? "" : ", ", value);
first = false;
}(Is)),
...);
printf("sequence{size: %d, data: [", size());
((printf("%d ", Is)), ...);
printf("]}");
}
printf(">");
}
};
namespace impl {
template <typename T, T... Ints>

View File

@@ -11,6 +11,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/pk_int4.hpp"
#include "ck_tile/core/numeric/e8m0.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
@@ -88,7 +89,13 @@ template <typename T, typename = void>
struct vector_traits
{
using scalar_type =
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>, int8_t, remove_cvref_t<T>>;
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_int4_t>,
int8_t,
std::conditional_t<std::is_same_v<remove_cvref_t<T>, pk_fp4_t> ||
std::is_same_v<remove_cvref_t<T>, e8m0_t>,
uint8_t,
remove_cvref_t<T>>>;
static constexpr index_t vector_size = 1;
};
@@ -96,7 +103,13 @@ struct vector_traits
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N))), void>
{
using scalar_type = std::conditional_t<std::is_same_v<T, pk_int4_t>, int8_t, T>;
using scalar_type = std::conditional_t<
std::is_same_v<T, pk_int4_t>,
int8_t,
std::conditional_t<std::is_same_v<T, pk_fp4_t> || std::is_same_v<remove_cvref_t<T>, e8m0_t>,
uint8_t,
T>>;
static constexpr index_t vector_size = N;
};
@@ -138,6 +151,9 @@ using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
using cktile_llvm_bf16x2_t = __bf16 __attribute__((ext_vector_type(2)));
using cktile_llvm_bf16x4_t = __bf16 __attribute__((ext_vector_type(4)));
// i32
// using int32_t = ...
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
@@ -237,4 +253,10 @@ using pk_int4x4_t = int8_t __attribute__((ext_vector_type(4)));
using pk_int4x8_t = int8_t __attribute__((ext_vector_type(8)));
using pk_int4x16_t = int8_t __attribute__((ext_vector_type(16)));
using pk_int4x32_t = int8_t __attribute__((ext_vector_type(32)));
using pk_fp4x2_t = uint8_t __attribute((ext_vector_type(2)));
using pk_fp4x4_t = uint8_t __attribute((ext_vector_type(4)));
using pk_fp4x8_t = uint8_t __attribute((ext_vector_type(8)));
using pk_fp4x16_t = uint8_t __attribute((ext_vector_type(16)));
using pk_fp4x32_t = uint8_t __attribute((ext_vector_type(32)));
} // namespace ck_tile

View File

@@ -210,6 +210,28 @@ struct buffer_view<address_space_enum::generic,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: generic, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Global
@@ -247,7 +269,7 @@ struct buffer_view<address_space_enum::global,
: p_data_{p_data},
buffer_size_{buffer_size / PackedSize},
cached_buf_res_{0},
invalid_element_value_{0}
invalid_element_value_{}
{
}
@@ -631,14 +653,24 @@ struct buffer_view<address_space_enum::global,
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
||
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
#endif
;
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
#elif (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0)
#if defined(__gfx950__) // only gfx950 support atomic_pk_add_bf16
||
(std::is_same_v<remove_cvref_t<scalar_t>, bfloat16_t> && scalar_per_x_vector % 2 == 0)
#endif
;
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
@@ -735,6 +767,28 @@ struct buffer_view<address_space_enum::global,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Global, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: LDS
@@ -950,34 +1004,51 @@ struct buffer_view<address_space_enum::lds,
// ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
// ds_write_b128
// TODO: remove this after compiler fix
// clang-format off
static_assert(
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> && std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> && std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> && std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x2_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x4_t> &&
std::is_same_v<remove_cvref_t<X>, int8x4_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x8_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8x16_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
// int8 on thread buffer
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 1>>) ||
// ext_vector_type for pk_int4 must use int8_t as type
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> && std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 1>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 2>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 4>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 8>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4x16_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>),
"wrong! not implemented for this combination, please add "
"implementation");
// clang-format on
if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8_t>) ||
@@ -1029,8 +1100,6 @@ struct buffer_view<address_space_enum::lds,
}
else if constexpr((std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, int8x16_t>) ||
(std::is_same_v<remove_cvref_t<T>, int8_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<int8_t, 16>>) ||
(std::is_same_v<remove_cvref_t<T>, pk_int4_t> &&
std::is_same_v<remove_cvref_t<X>, thread_buffer<pk_int4_t, 16>>))
{
@@ -1094,6 +1163,28 @@ struct buffer_view<address_space_enum::lds,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Lds, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
// Address Space: Vgpr
@@ -1247,6 +1338,28 @@ struct buffer_view<address_space_enum::vgpr,
// FIXME: remove
CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }
CK_TILE_HOST_DEVICE void print() const
{
printf("buffer_view{");
// AddressSpace
printf("AddressSpace: Vgpr, ");
// p_data_
printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));
// buffer_size_
printf("buffer_size_: ");
print(buffer_size_);
printf(", ");
// invalid_element_value_
printf("invalid_element_value_: ");
print(invalid_element_value_);
printf("}");
}
};
template <address_space_enum BufferAddressSpace,
@@ -1272,25 +1385,4 @@ make_buffer_view(T* __restrict__ p, BufferSizeType buffer_size, X invalid_elemen
p, buffer_size, invalid_element_value};
}
// Generalized print function for all buffer_view variants
template <address_space_enum BufferAddressSpace,
typename T,
typename BufferSizeType,
bool InvalidElementUseNumericalZeroValue,
amd_buffer_coherence_enum Coherence>
CK_TILE_HOST_DEVICE void print(const buffer_view<BufferAddressSpace,
T,
BufferSizeType,
InvalidElementUseNumericalZeroValue,
Coherence>& bv)
{
printf("buffer_view{AddressSpace: %s, p_data_: %p, buffer_size_: ",
address_space_to_string(BufferAddressSpace),
static_cast<void*>(const_cast<remove_cvref_t<T>*>(bv.p_data_)));
print(bv.buffer_size_);
printf(", invalid_element_value_: ");
print(bv.invalid_element_value_);
printf("}");
}
} // namespace ck_tile

View File

@@ -404,6 +404,171 @@ struct tile_scatter_gather
});
}
template <typename DistributedTensor,
typename Ys2PageIdxMap,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
const Ys2PageIdxMap& ys_to_page_idx_map,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
const auto idx_gather = ys_to_page_idx_map(idx_ys_start);
const auto page_offset = page_idx_[idx_gather];
// read from bottom tensor
const vector_t vec_value = [&]() {
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
bool_constant<oob_conditional_check>{});
}
else
{
return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
bool_constant<oob_conditional_check>{});
}
}();
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j / Traits::PackedSize];
});
// ys_to_page_idx_map handles all offset calculation.
// So ther is no need to move thread coordinate redundantly.
});
});
}
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// Precompute invariant values outside loops
const auto window_origin = lds_tile.get_window_origin();
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
auto lds_window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto lds_bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// Use precomputed window origin
auto lds_bottom_tensor_thread_idx =
window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
// Use precomputed tensor descriptor
const auto lds_coord =
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
// Calculate SMEM address using base pointer
CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
// merge page_offset into bottom_coord
auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
// read from bottom tensor
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
mixed_bottom_thread_coord,
number<0>{},
bool_constant<oob_conditional_check>{});
else
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
mixed_bottom_thread_coord,
number<0>{},
valids_[idx_gather],
bool_constant<oob_conditional_check>{});
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
// lds_diff doesn't need to mask the difference of the gather-dim.
constexpr auto lds_idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
lds_window_adaptor_thread_coord,
lds_bottom_tensor_thread_coord,
lds_idx_diff_ps_ys);
}
});
});
}
// TODO: currently async load only implemented in inline asm
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
@@ -599,6 +764,88 @@ struct tile_scatter_gather
});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
using Traits = load_store_traits;
// using vector_type_t = typename Traits::vector_type_t;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<0>{}];
const auto page_offset = page_idx_[idx_gather];
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
Traits::PackedSize;
vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
vec_value,
bool_constant<oob_conditional_check>{});
}
else
{
get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
page_offset,
valids_[idx_gather],
vec_value,
bool_constant<oob_conditional_check>{});
}
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// move thread's botom tensor coordiante
// [x0', x1', ... ] ==> [offset]
// also move window-origin
@@ -810,6 +1057,31 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
}
template <typename NewTensorView_,
typename OldTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
typename StaticPageIndexArray_,
typename StaticValidArray_,
index_t HsGatherDim = 0,
index_t NumCoord = 1>
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
const tile_scatter_gather<OldTensorView_,
WindowLengths_,
StaticTileDistribution_,
StaticPageIndexArray_,
StaticValidArray_,
HsGatherDim,
NumCoord>& tile_window)
{
return make_tile_scatter_gather(new_tensor_view,
tile_window.window_lengths_,
tile_window.window_origin_,
tile_window.tile_dstr_,
tile_window.page_idx_,
tile_window.valids_);
}
template <typename TensorView,
typename WindowLengths,
typename StaticTileDistribution,

View File

@@ -840,6 +840,24 @@ make_tile_window_raw(const TensorView_& tensor_view,
return w;
}
template <typename NewTensorView_,
typename OldTensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
index_t NumCoord = 1>
CK_TILE_DEVICE auto
replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
const tile_window_with_static_distribution<OldTensorView_,
WindowLengths_,
StaticTileDistribution_,
NumCoord>& tile_window)
{
return make_tile_window(new_tensor_view,
tile_window.get_window_lengths(),
tile_window.get_window_origin(),
tile_window.get_tile_distribution());
}
template <typename TensorView_,
typename WindowLengths_,
typename StaticTileDistribution_,
@@ -1001,6 +1019,15 @@ make_tile_window_raw(const tile_window_with_static_lengths<TensorView, WindowLen
return w;
}
template <typename NewTensorView_, typename OldTensorView_, typename WindowLengths_>
CK_TILE_DEVICE auto replace_bottom_tensor_view(
const NewTensorView_& new_tensor_view,
const tile_window_with_static_lengths<OldTensorView_, WindowLengths_>& tile_window)
{
return make_tile_window(
new_tensor_view, tile_window.get_window_lengths(), tile_window.get_window_origin());
}
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void move_tile_window(
tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,

View File

@@ -3,8 +3,6 @@
#pragma once
#include <numeric>
#include <functional>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/host/hip_check_error.hpp"
@@ -15,9 +13,15 @@
namespace ck_tile {
template <int MinBlockPerCu, typename Kernel, typename... Args>
#define LOW_CU_PROCESSORS 80
#define HIGH_CU_PROCESSORS 228
#define OPTIMAL_LATENCY_LOW_CU_PROCESSORS 0.005
#define OPTIMAL_LATENCY_HIGH_CU_PROCESSORS 0.0015
#define OPTIMAL_LATENCY_SAFE_MARGIN 0.01
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
#if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
#endif
__global__ void kentry(Args... args)
{
@@ -28,6 +32,11 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
#endif
}
template <int MaxThreadPerBlock, typename Kernel, typename... Args>
__launch_bounds__(MaxThreadPerBlock) __global__ void kentry2(Args... args)
{
Kernel{}(args...);
}
//
// return a anonymous functor(lambda) to be called later
// the KernelImpl should be a class without non-static data member, or let's say
@@ -35,11 +44,15 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
//
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
//
template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU, typename KernelImpl, typename... Args>
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
typename KernelImpl,
typename... Args>
CK_TILE_HOST auto
make_kernel(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{
const auto kernel = kentry<MinBlockPerCu, KernelImpl, Args...>;
const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
return [=](const stream_config& s) {
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
};
@@ -55,60 +68,6 @@ CK_TILE_HOST void launch_and_check(const stream_config& sc, Callables&&... calla
}
}
// Measure the preprocess time during the cold iterations
template <typename TimerType, typename PreprocessFunc>
CK_TILE_HOST double
preprocess_profiling_impl(TimerType timer, const stream_config& s, PreprocessFunc preprocess)
{
timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++)
{
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
{
preprocess();
}
}
timer.stop(s.stream_id_);
return timer.duration() / s.nrepeat_;
}
template <typename TimerType, typename CallablesFunc, typename PreprocessFunc = std::nullptr_t>
CK_TILE_HOST double timing_loop_impl(TimerType timer,
const stream_config& s,
CallablesFunc&& callables_func,
PreprocessFunc preprocess = nullptr)
{
for(int i = 0; i < s.cold_niters_; i++)
{
callables_func();
}
// Only profile preprocess if it's provided
auto preprocess_time = 0.0;
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
{
preprocess_time = preprocess_profiling_impl(gpu_timer{}, s, preprocess);
}
int i = 0;
timer.start(s.stream_id_);
while(i < s.nrepeat_)
{
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
{
preprocess();
}
callables_func();
i++;
}
timer.stop(s.stream_id_);
if(!i)
return 0.;
return (timer.duration() / s.nrepeat_) - preprocess_time;
}
// clang-format off
/*
* launch_kernel()
@@ -147,21 +106,37 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables&&... callable
return 0;
}
auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };
auto time_launches = [&](auto timer) {
// Warmup
for(int i = 0; i < s.cold_niters_; i++)
{
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++)
{
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.stop(s.stream_id_);
return timer.duration() / s.nrepeat_;
};
if(s.is_gpu_timer_)
{
return timing_loop_impl(gpu_timer{}, s, callables_func);
return time_launches(gpu_timer{});
}
else
{
return timing_loop_impl(cpu_timer{}, s, callables_func);
return time_launches(cpu_timer{});
}
}
template <typename PreprocessFunc, typename... Callables>
CK_TILE_HOST float
launch_kernel_time_mask(const stream_config& s, PreprocessFunc preprocess, Callables&&... callables)
CK_TILE_HOST float launch_kernel_preprocess(const stream_config& s,
PreprocessFunc preprocess,
Callables&&... callables)
{
static_assert(sizeof...(callables) > 0, "At least one callable is required!");
@@ -172,15 +147,39 @@ launch_kernel_time_mask(const stream_config& s, PreprocessFunc preprocess, Calla
return 0;
}
auto callables_func = [&]() { launch_and_check(s, std::forward<Callables>(callables)...); };
auto time_launches = [&](auto timer) {
// Warmup
for(int i = 0; i < s.cold_niters_; i++)
{
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.start(s.stream_id_);
for(int i = 0; i < s.nrepeat_; i++)
{
preprocess();
launch_and_check(s, std::forward<Callables>(callables)...);
}
timer.stop(s.stream_id_);
hipDeviceProp_t deviceProps;
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0));
float preprocess_offset = (deviceProps.multiProcessorCount >= HIGH_CU_PROCESSORS)
? OPTIMAL_LATENCY_HIGH_CU_PROCESSORS
: (deviceProps.multiProcessorCount == LOW_CU_PROCESSORS)
? OPTIMAL_LATENCY_LOW_CU_PROCESSORS
: OPTIMAL_LATENCY_SAFE_MARGIN;
return (timer.duration() - preprocess_offset * s.nrepeat_) / s.nrepeat_;
};
if(s.is_gpu_timer_)
{
return timing_loop_impl(gpu_timer{}, s, callables_func, preprocess);
return time_launches(gpu_timer{});
}
else
{
return timing_loop_impl(cpu_timer{}, s, callables_func, preprocess);
return time_launches(cpu_timer{});
}
}
} // namespace ck_tile

View File

@@ -11,196 +11,6 @@
namespace ck_tile {
template <typename ADataType,
typename QDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
uint32_t QuantGroupSize,
bool aquant,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<QDataType>& q,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0, v_block_acc = 0;
static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
std::is_same_v<BDataType, pk_int4_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> ||
std::is_same_v<CDataType, ck_tile::half_t>);
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_block_acc += v_a * v_b;
// Apply group dequant scale
if((k + 1) % QuantGroupSize == 0)
{
float scale = 0.f;
index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
index_t inner_dim = (aquant) ? k / QuantGroupSize : n;
if constexpr(std::is_same_v<QDataType, float>)
{
scale = q(outer_dim, inner_dim);
}
else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
{
scale = fp8_to_float_raw(q(outer_dim, inner_dim));
}
else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
{
scale = bf8_to_float_raw(q(outer_dim, inner_dim));
}
else
{
static_assert(false, "Unexpected Q datatype.");
}
v_block_acc *= scale;
v_acc += v_block_acc;
v_block_acc = 0;
}
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
std::cout << std::endl;
}
template <typename ADataType,
typename AQDataType,
typename BDataType,
typename BQDataType,
typename AccDataType,
typename CDataType,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<AQDataType>& aq_m_1,
const HostTensor<BDataType>& b_k_n,
const HostTensor<BQDataType>& bq_1_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
static_assert(std::is_same_v<AccDataType, float>);
static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
// Init accumulator
AccDataType v_acc = 0;
// Get row scale for A and column scale for B
float a_scale = aq_m_1(m, 0);
float b_scale = bq_1_n(0, n);
// Compute the dot product
for(std::size_t k = 0; k < K; ++k)
{
AccDataType v_a;
AccDataType v_b;
// Process A data
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
}
// Process B data
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, fp8_t>)
{
v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
}
else
{
v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
}
v_acc += v_a * v_b;
}
v_acc = v_acc * a_scale * b_scale;
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
std::cout << std::endl;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
@@ -359,6 +169,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
@@ -371,6 +189,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
@@ -385,6 +211,121 @@ __global__ void naive_gemm_kernel(ADataType* A,
}
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
__global__ void blockwise_gemm_kernel(ADataType* A,
BDataType* B,
CDataType* C,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t strideA,
ck_tile::index_t strideB,
ck_tile::index_t strideC,
ck_tile::index_t scale_granularity_m,
ck_tile::index_t scale_granularity_n,
ck_tile::index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int row = idx / N; // Compute row index
int col = idx % N; // Compute column index
if(row < M && col < N)
{
AccDataType acc = 0.0, acc_temp = 0.0;
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
float scale_A = 0;
float scale_B = 0;
for(int k = 0; k < K; ++k)
{
if(k % scale_granularity_k == 0)
{
// update acc
acc += acc_temp * scale_A * scale_B;
acc_temp = 0.0;
// update scale factors
scale_A = scale_A_ptr[(row / scale_granularity_m) +
(k / scale_granularity_k) * scale_A_stride];
scale_B = scale_B_ptr[(col / scale_granularity_n) +
(k / scale_granularity_k) * scale_B_stride];
}
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? row * strideA + k
: k * strideA + row;
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
AccDataType v_a;
AccDataType v_b;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
}
acc_temp += v_a * v_b;
}
// final accumulation
acc += acc_temp * scale_A * scale_B;
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
C[c_index] = ck_tile::type_convert<CDataType>(acc);
}
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
@@ -413,6 +354,51 @@ void reference_gemm_gpu(ADataType* a_ptr,
return;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC>
void reference_blockwise_gemm_gpu(ADataType* a_ptr,
BDataType* b_ptr,
CDataType* c_ptr,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c,
index_t scale_granularity_m,
index_t scale_granularity_n,
index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr)
{
int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
blockwise_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(a_ptr,
b_ptr,
c_ptr,
M,
N,
K,
stride_a,
stride_b,
stride_c,
scale_granularity_m,
scale_granularity_n,
scale_granularity_k,
scale_A_ptr,
scale_B_ptr);
return;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
@@ -450,4 +436,5 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
return;
}
} // namespace ck_tile

View File

@@ -0,0 +1,315 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
typename ActivationOp = identity>
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
const ck_tile::index_t* p_sorted_expert_ids_,
const ck_tile::index_t* p_max_token_id_,
const ADataType* A,
const BDataType* B,
CDataType* C,
const AccDataType* expert_weight_ptr,
ck_tile::index_t Num_tokens,
ck_tile::index_t TokensPerBlock,
ck_tile::index_t TopK,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t strideA,
ck_tile::index_t strideB,
ck_tile::index_t strideC,
index_t scale_granularity_m,
index_t scale_granularity_n,
index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr,
float* expert_bias_ptr)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
int row = idx / problem_N; // Compute row index
int col = idx % problem_N; // Compute column index
index_t gather_token_id = 0;
index_t scatter_token_id = 0;
index_t expert_id = 0;
if(row < p_max_token_id_[0])
{
expert_id = p_sorted_expert_ids_[row / TokensPerBlock];
gather_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
if(gather_token_id >= Num_tokens)
{
return;
}
if(MoeGemmKind == 2)
{
gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
}
else
{
scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
}
}
else
{
return;
}
if(row < M)
{
AccDataType acc = 0.0;
AccDataType acc_up = 0.0;
AccDataType acc_temp = 0.0;
AccDataType acc_up_temp = 0.0;
float scale_A = 0;
float scale_B = 0;
float scale_B_up = 0;
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;
for(int k = 0; k < K; ++k)
{
if(k % scale_granularity_k == 0)
{
// update acc
acc += acc_temp * scale_A * scale_B;
acc_up += acc_up_temp * scale_A * scale_B_up;
// reset acc temp
acc_temp = 0.0;
acc_up_temp = 0.0;
// update scale factors
scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
(k / scale_granularity_k) * scale_A_stride];
scale_B =
scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
(k / scale_granularity_k) * scale_B_stride];
if constexpr(MoeGemmKind == 1)
scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
(col + problem_N) / scale_granularity_n +
(k / scale_granularity_k) * scale_B_stride];
}
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? gather_token_id * strideA + k
: k * strideA + gather_token_id;
long b_index =
long(expert_id) * N * K +
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
: k * strideB + col);
long b_index_up;
if constexpr(MoeGemmKind == 1)
b_index_up = long(expert_id) * N * K +
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? (col + problem_N) * strideB + k
: k * strideB + col + problem_N);
AccDataType v_a;
AccDataType v_b;
AccDataType v_b_up;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
if constexpr(MoeGemmKind == 1)
{
const fp32x2_t fp32_val_up =
pk_int4_t_to_fp32x2_t(B[b_index_up / packed_size_b]);
if(k % 2 == 1)
v_b_up = fp32_val_up.hi;
else
v_b_up = fp32_val_up.lo;
}
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
if constexpr(MoeGemmKind == 1)
{
const fp32x2_t fp32_val_up = pk_fp4_to_fp32x2(B[b_index_up / packed_size_b], 1.0f);
if(k % 2 == 1)
v_b_up = fp32_val_up.hi;
else
v_b_up = fp32_val_up.lo;
}
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
if constexpr(MoeGemmKind == 1)
v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
}
acc_temp += v_a * v_b;
if constexpr(MoeGemmKind == 1)
acc_up_temp += v_a * v_b_up;
}
acc += acc_temp * scale_A * scale_B;
acc_up += acc_up_temp * scale_A * scale_B_up;
float bias = 0.f, bias_up = 0.f;
if(expert_bias_ptr != nullptr)
{
bias = expert_bias_ptr[expert_id * N + col];
if constexpr(MoeGemmKind == 1)
bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? scatter_token_id * strideC + col
: col * strideC + scatter_token_id;
if constexpr(MoeGemmKind < 2)
{
C[c_index] = ck_tile::type_convert<CDataType>(
ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
}
else
{
// moe gemm2 don't use activation.
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
ck_tile::fp16x2_t,
ck_tile::bf16x2_t>;
ResV2Type add_v{0, 0};
if(c_index % 2)
{
// result is the second value of fp16 pair.
add_v.y = res;
}
else
{
// result is the first value of fp16 pair.
add_v.x = res;
}
// mask last bit to make sure atomicAdd pointer is aligned of DWORD.
atomic_add<ResV2Type>(reinterpret_cast<ResV2Type*>(C + (c_index & 0xffff'fffe)), add_v);
}
}
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
typename ActivationOp = identity>
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
const index_t* p_sorted_expert_ids_,
const index_t* p_max_token_id_,
const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
const AccDataType* expert_weight_ptr,
index_t Num_tokens,
index_t TokensPerBlock,
index_t TopK,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c,
index_t scale_granularity_m,
index_t scale_granularity_n,
index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr,
float* exp_bias = nullptr)
{
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
int totalElements = M * problem_N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
moe_gemm_kernel<ADataType,
BDataType,
AccDataType,
CDataType,
LayoutA,
LayoutB,
LayoutC,
MoeGemmKind,
ActivationOp><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
p_sorted_expert_ids_,
p_max_token_id_,
a_ptr,
b_ptr,
c_ptr,
expert_weight_ptr,
Num_tokens,
TokensPerBlock,
TopK,
M,
N,
K,
stride_a,
stride_b,
stride_c,
scale_granularity_m,
scale_granularity_n,
scale_granularity_k,
scale_A_ptr,
scale_B_ptr,
exp_bias);
return;
}
} // namespace ck_tile

View File

@@ -883,7 +883,14 @@ struct Sigmoid
std::is_same_v<T, int32_t>,
"Data type is not supported by this operation!");
constexpr T one = type_convert<T>(1);
y = one / (one + ck_tile::exp(-x));
if constexpr(std::is_same_v<T, float>)
{
y = x * __builtin_amdgcn_rcpf(one + ck_tile::exp(-x));
}
else
{
y = x * (one / (one + ck_tile::exp(-x)));
}
};
};

View File

@@ -48,7 +48,8 @@ template <typename ADataType_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
bool TiledMMAPermuteN_ = false>
bool TiledMMAPermuteN_ = false,
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
struct CShuffleEpilogueProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -71,6 +72,7 @@ struct CShuffleEpilogueProblem
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -108,6 +110,7 @@ struct CShuffleEpilogue
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
@@ -213,7 +216,8 @@ struct CShuffleEpilogue
}
}();
static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
static constexpr index_t NumNXdlPerWavePerShuffle =
max(BlockedXDLN_PerWarp, std::get<1>(shuffle_tile_tuple));
static constexpr auto MNPerIterationShuffle = [] {
constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
@@ -266,14 +270,31 @@ struct CShuffleEpilogue
CK_TILE_DEVICE static constexpr auto MakeLdsDistributionEncode()
{
constexpr auto block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto block_outer_dstr_encoding = [] {
if constexpr(BlockedXDLN_PerWarp == 1)
{
return tile_distribution_encoding<sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<NumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
}
else
{
constexpr int RakedXDLN_PerWarp = NumNXdlPerWavePerShuffle / BlockedXDLN_PerWarp;
// BlockedLayout
return tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
sequence<RakedXDLN_PerWarp, NWave, BlockedXDLN_PerWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{};
}
}();
constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});

View File

@@ -10,8 +10,11 @@
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v0.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -113,6 +113,7 @@ struct BlockFlatmmASmemBSmemCRegV1
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
__builtin_amdgcn_sched_barrier(0x7F6);
});
});
});

498
include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp Normal file → Executable file
View File

@@ -11,23 +11,138 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
struct FlatmmProblem
{
CK_TILE_HOST FlatmmProblem() = default;
CK_TILE_HOST FlatmmProblem(
index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
: M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
{
}
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
template <int SharedGranularityMN, int SharedGranularityK = 0>
struct FlatmmScalePointer
{
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = SharedGranularityK;
const float* ptr;
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
: ptr(ptr_)
{
}
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
{
FlatmmScalePointer ret;
if constexpr(GranularityMN == 0)
{
ret.ptr = ptr + offset / GranularityK;
}
else
{
ret.ptr = ptr + offset / GranularityMN / GranularityK;
}
return ret;
}
CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
};
template <int SharedGranularityMN>
struct FlatmmScalePointer<SharedGranularityMN, 0>
{
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = 0;
static_assert(GranularityMN != 0);
const float* ptr;
index_t length;
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
: ptr(ptr_), length(length_)
{
}
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
{
FlatmmScalePointer ret;
if constexpr(GranularityMN == 1)
{
ret.ptr = ptr + offset;
ret.length = length - offset;
}
else
{
ret.ptr = ptr + offset / GranularityMN;
ret.length = length - offset / GranularityMN;
}
return ret;
}
CK_TILE_HOST_DEVICE float operator[](index_t i) const
{
// with additional oob check
if constexpr(GranularityMN == 1)
return i < length ? ptr[i] : 0;
else
return i / GranularityMN < length ? ptr[i / GranularityMN] : 0;
}
};
// shared granularityMN = -1 means no scale
template <>
struct FlatmmScalePointer<-1, 0>
{
static constexpr int GranularityMN = -1;
static constexpr int GranularityK = 0;
const float* ptr = nullptr;
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
{
return FlatmmScalePointer{};
}
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
{
return 1; // alway return 1, it doesn't change the result
}
};
template <index_t NumDTensor = 0>
struct FlatmmHostArgs
struct BaseFlatmmHostArgs
{
CK_TILE_HOST FlatmmHostArgs() = default;
CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_,
const void* b_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
CK_TILE_HOST BaseFlatmmHostArgs() = default;
CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_,
const void* b_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* e_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_E_)
: a_ptr(a_ptr_),
b_ptr(b_ptr_),
ds_ptr(ds_ptr_),
@@ -65,8 +180,51 @@ struct FlatmmHostArgs
index_t k_batch;
};
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<>
{
CK_TILE_HOST ScaleFlatmmHostArgs() = default;
CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_,
const void* b_shuffle_ptr_,
const std::array<const void*, NumDTensor>& ds_ptr_,
void* c_ptr_,
index_t k_batch_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
const std::array<index_t, NumDTensor>& stride_Ds_,
index_t stride_C_,
ScaleM scale_m_ = nullptr,
ScaleN scale_n_ = nullptr)
: BaseFlatmmHostArgs(a_ptr_,
b_shuffle_ptr_,
ds_ptr_,
c_ptr_,
k_batch_,
M_,
N_,
K_,
stride_A_,
stride_B_,
stride_Ds_,
stride_C_),
scale_m(scale_m_),
scale_n(scale_n_)
{
}
ScaleM scale_m = nullptr;
ScaleN scale_n = nullptr;
};
template <index_t NumDTensor = 0>
template <int NumberTensor = 0>
using FlatmmHostArgs =
ScaleFlatmmHostArgs<FlatmmScalePointer<-1>, FlatmmScalePointer<-1>, NumberTensor>;
template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
struct FlatmmKernelArgs
{
const void* a_ptr;
@@ -82,6 +240,8 @@ struct FlatmmKernelArgs
std::array<index_t, NumDTensor> stride_Ds;
index_t stride_E;
index_t k_batch;
ScaleM scale_m_ptr = nullptr;
ScaleN scale_n_ptr = nullptr;
};
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
@@ -91,13 +251,14 @@ struct FlatmmKernel
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
using BlockGemmShape =
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
@@ -113,7 +274,7 @@ struct FlatmmKernel
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
@@ -124,40 +285,87 @@ struct FlatmmKernel
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
assert(!UsePersistentKernel);
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
}
CK_TILE_HOST static constexpr auto BlockSize()
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr auto
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
{
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
if constexpr(UsePersistentKernel)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
constexpr int block_size = FlatmmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry2<block_size,
FlatmmKernel,
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
block_size,
dync_smem_size);
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
// << ", persistent_block_size: " << persistent_block_size
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
assert(kargs.k_batch == 1);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
}
else
{
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
}
}
CK_TILE_HOST static constexpr KernelArgs
MakeKernelArgs(const FlatmmHostArgs<NumDTensor>& hostArgs)
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
{
return KernelArgs{hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.ds_ptr,
hostArgs.e_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_Ds,
hostArgs.stride_E,
hostArgs.k_batch};
return {hostArgs.a_ptr,
hostArgs.b_ptr,
hostArgs.ds_ptr,
hostArgs.e_ptr,
hostArgs.M,
hostArgs.N,
hostArgs.K,
hostArgs.stride_A,
hostArgs.stride_B,
hostArgs.stride_Ds,
hostArgs.stride_E,
hostArgs.k_batch,
hostArgs.scale_m,
hostArgs.scale_n};
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
{
return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
{
return FlatmmPipeline::GetSmemSize();
}
struct SplitKBatchOffset
{
template <class KernelArgs>
__device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
{
constexpr auto N1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<1>{});
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t K_t = kargs.k_batch * K1;
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
@@ -173,11 +381,11 @@ struct FlatmmKernel
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * KRead * kargs.stride_B;
b_k_split_offset = k_id * KRead * kargs.stride_B * N1;
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = k_id * KRead;
b_k_split_offset = k_id * KRead * N1;
}
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
@@ -195,6 +403,7 @@ struct FlatmmKernel
index_t splitted_k;
};
template <class KernelArgs>
CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
{
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
@@ -206,6 +415,14 @@ struct FlatmmKernel
return false;
}
}
if constexpr(UsePersistentKernel)
{
if(kargs.k_batch != 1)
{
std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
return false;
}
}
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
@@ -340,7 +557,7 @@ struct FlatmmKernel
return DTesnorIsValid;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
@@ -370,9 +587,9 @@ struct FlatmmKernel
}
}();
index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
@@ -411,7 +628,7 @@ struct FlatmmKernel
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_E, 1),
@@ -420,7 +637,7 @@ struct FlatmmKernel
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_E, 1),
@@ -429,7 +646,45 @@ struct FlatmmKernel
}
}();
return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
: 1; // per-token scale
auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
: 1; // per-channel scale
static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
"only support per-tensor or per-row scaling");
static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
"only support per-tensor or per-column scaling");
const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_m_ptr.ptr,
make_tuple(
kargs.M / ScaleGranularityM,
ScaleGranularityKA == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKA),
make_tuple(scale_stride_m, 0),
number<ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1>{},
number<1>{});
const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
kargs.scale_n_ptr.ptr,
make_tuple(
ScaleGranularityKB == 0 ? 1 : splitk_batch_offset.splitted_k / ScaleGranularityKB,
kargs.N / ScaleGranularityN),
make_tuple(0, scale_stride_n),
number<ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1>{},
number<1>{});
return make_tuple(a_tensor_view,
b_flat_tensor_view,
ds_tensor_view,
e_tensor_view,
scale_m_view,
scale_n_view);
}
template <typename TensorView>
@@ -495,7 +750,12 @@ struct FlatmmKernel
}
}();
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
return make_tuple(a_pad_view,
b_flat_tensor_view,
ds_pad_view,
e_pad_view,
views.at(number<4>{}),
views.at(number<5>{}));
}
template <typename PadView>
@@ -555,19 +815,42 @@ struct FlatmmKernel
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;
auto scale_m_window = make_tile_window(
views.at(number<4>{}),
make_tuple(number<TilePartitioner::MPerBlock>{},
number<ScaleGranularityKA == 0 ? TilePartitioner::NPerBlock
: TilePartitioner::KPerBlock>{}),
{i_m, 0});
auto scale_n_window = make_tile_window(
views.at(number<5>{}),
make_tuple(number<ScaleGranularityKB == 0 ? TilePartitioner::MPerBlock
: TilePartitioner::KPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{0, i_n});
return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_m_window,
scale_n_window);
}
template <bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void
RunFlatmm(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
@@ -583,50 +866,77 @@ struct FlatmmKernel
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = FlatmmPipeline{}.template operator()(
a_block_window, b_flat_block_window, num_loop, smem_ptr);
if(UseDefaultScheduler || (get_warp_id() == 0))
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
auto scale_m_window = gemm_tile_windows.at(number<4>{});
auto scale_n_window = gemm_tile_windows.at(number<5>{});
// Run Epilogue Pipeline
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
scale_m_window,
scale_n_window);
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr);
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
}
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
template <class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
int partition_idx = blockIdx.x) const
{
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
do
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
const auto [iM, iN] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS
__shared__ char smem_ptr_ping[GetSmemPingSize()];
__shared__ char smem_ptr_pong[GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
partition_idx += gridDim.x;
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
}
};

View File

@@ -0,0 +1,458 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"
namespace ck_tile {
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
{
using Underlying = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
using BlockGemmShape =
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
// Below type is actually accumulation data type - the output of block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
static constexpr int QuantPackedSize = numeric_traits<BDataType>::PackedSize;
static constexpr int N_Pack = 2;
static constexpr index_t NumDTensor = DsDataType::size();
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static constexpr auto I4 = number<4>();
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
// using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "mixed_prec_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
// clang-format on
}
template <class ScaleM, class ScaleN>
CK_TILE_HOST static constexpr auto
GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
{
if constexpr(UsePersistentKernel)
{
hipDeviceProp_t prop;
int deviceId = 0; // default device
constexpr int block_size = F16xMXF4FlatmmKernel::BlockSize().x;
int dync_smem_size = 0;
int maxActiveBlocksPerCU = 0;
[[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&maxActiveBlocksPerCU,
reinterpret_cast<void*>(
kentry2<block_size,
F16xMXF4FlatmmKernel,
FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
block_size,
dync_smem_size);
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
// << ", persistent_block_size: " << persistent_block_size
// << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
assert(kargs.k_batch == 1);
return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
}
else
{
return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
}
}
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset;
template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
make_tuple(kargs.stride_A, 1),
number<FlatmmPipeline::GetVectorSizeA()>{},
number<1>{});
}
}();
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1);
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
make_tuple(kFlatN, kFlatK),
make_tuple(kFlatK, 1),
number<FlatmmPipeline::GetVectorSizeB()>{},
number<1>{});
}();
const auto& ds_tensor_view = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
static_cast<const DDataType_*>(ds_ptr[i]),
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_Ds[i], 1),
number<EpiloguePipeline::GetVectorSizeD(i)>{},
number<1>{});
}
},
number<NumDTensor>{});
// TODO: enable vector write for C in ColMajor
const auto& e_tensor_view = [&]() {
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_E, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
}
else
{
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(kargs.N, kargs.M),
make_tuple(kargs.stride_E, 1),
number<1>{},
number<1>{});
}
}();
auto scale_n = kargs.scale_n_ptr;
index_t FlatScaleK =
(kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const e8m0_t*>(scale_n.ptr),
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
return make_tuple(
a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view);
}
template <typename TensorView>
CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
{
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, FlatmmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(a_tensor_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
}();
const auto& b_flat_tensor_view = views.at(I1);
const auto& ds_pad_view = generate_tuple(
[&](auto i) {
const auto& d_tensor_view = views.at(I2);
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(d_tensor_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
sequence<false, FlatmmPipeline::kPadM>{});
}
},
number<NumDTensor>{});
// TODO vector write in for C in ColMajor
const auto& e_pad_view = [&]() {
const auto& e_tensor_view = views.at(I3);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<false, FlatmmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
sequence<FlatmmPipeline::kPadM, false>{});
}
}();
return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view, views.at(I4));
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
const auto& a_pad_view = views.at(I0);
const auto& b_flat_pad_view = views.at(I1);
const auto& ds_pad_view = views.at(I2);
const auto& e_pad_view = views.at(I3);
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
}
else
{
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::KPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{0, i_m});
}
}();
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
const auto ds_block_window = generate_tuple(
[&](auto i) {
using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
}
else
{
return make_tile_window(ds_pad_view[i],
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::MPerBlock>{}),
{i_n, i_m});
}
},
number<NumDTensor>{});
auto e_block_window = make_tile_window(
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
auto scale_block_window =
make_tile_window(views.at(I4),
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / 32>{}),
{i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
return make_tuple(a_block_window,
b_flat_block_window,
ds_block_window,
e_block_window,
scale_block_window);
}
template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void
RunFlatmm(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
EDataType* e_ptr,
void* smem_ptr_ping,
void* smem_ptr_pong,
const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
const SplitKBatchOffset& splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& scale_block_window = gemm_tile_windows.at(I4);
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
"ScaleM and ScaleN should have the same GranularityK");
constexpr bool DoEpiScale =
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
auto a_block_window_with_distr =
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
a_block_window.get_window_lengths(),
a_block_window.get_window_origin(),
FlatmmPipeline::GetADramTileDistribution());
const auto& c_block_tile = FlatmmPipeline{}(a_block_window_with_distr,
b_flat_block_window,
scale_block_window,
num_loop,
smem_ptr_ping,
smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(DoEpiScale)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}(c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
}
template <class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
int partition_idx = blockIdx.x) const
{
int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
do
{
const auto [iM, iN] =
TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
const SplitKBatchOffset splitk_batch_offset(kargs);
// options
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset / QuantPackedSize;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
// allocate LDS
__shared__ char smem_ptr_ping[Underlying::GetSmemPingSize()];
__shared__ char smem_ptr_pong[Underlying::GetSmemPongSize()];
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
is_any_of<EDataType, fp16_t, bf16_t>::value))
{
constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
b_flat_ptr,
kargs.ds_ptr,
e_ptr,
smem_ptr_ping,
smem_ptr_pong,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
partition_idx += gridDim.x;
} while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,883 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
namespace ck_tile {
template <typename Problem>
struct BaseFlatmmPipelineAGmemBGmemCRegV0
{
static constexpr index_t PrefetchStages = 2;
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
template <typename RunFunction>
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber tail_num)
{
if (TailNumber::Even == tail_num)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Even>{});
}
else if (TailNumber::Odd == tail_num)
{
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Odd>{});
}
// assert(false);
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
// return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
}
};
template <typename Problem, typename PipelinePolicy = UniversalFlatmmPipelineAgBgCrPolicy>
struct FlatmmPipelineAGmemBGmemCRegV0
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockFlatmm =
remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
static constexpr auto config = BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr index_t kLdsAlignmentInBytes = 16;
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
static constexpr auto I0 = number<0>();
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto idxM = I0;
static constexpr auto idxN = I1;
static constexpr auto idxK = I2;
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
static constexpr index_t K1 = 16 / sizeof(ADataType);
static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1;
static constexpr index_t ACopyLoadNumPerK = ACopyLoadNum / KIterPerWarp;
static constexpr index_t AcopyPerLoadM = kMPerBlock / ACopyLoadNum;
static constexpr index_t BloadGap = MIterPerWarp / 2;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto warp_m = WarpTile::at(idxM);
static constexpr auto warp_n = WarpTile::at(idxN);
static constexpr auto warp_k = WarpTile::at(idxK);
/*
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP8) // mi300 fp8 16c 0.5*K1
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP8) // mi300 fp8 32c 0.5*K1
defined(USING_MFMA_16x16x16) && defined(ENABLE_FP16) // mi300 fp16 16c 0.5*K1
defined(USING_MFMA_32x32x8) && defined(ENABLE_FP16) // mi300 fp16 32c 0.5*K1
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP8) // mi350 fp8 32c 2*K1
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP8) // mi350 fp8 64c 2*K1
defined(USING_MFMA_16x16x32) && defined(ENABLE_FP16) // mi350 fp16 16c 1*K1
defined(USING_MFMA_32x32x16) && defined(ENABLE_FP16) // mi350 fp16 32c 1*K1
defined(USING_MFMA_16x16x128) && defined(ENABLE_FP4) // mi350 fp4 16c 1*K1
defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
*/
struct MfmaConfig
{
int mfma_per_wg;
int dsread_per_wg;
};
static constexpr MfmaConfig GetMfmaConfig()
{
// K1 per Mfma = 0.5 cases: mfma_per_wg = 2, dsread_per_wg = 1
if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 &&
std::is_same_v<ADataType, fp8_t>) ||
(warp_m == 32 && warp_n == 32 && warp_k == 16 &&
std::is_same_v<ADataType, fp8_t>) ||
(warp_m == 16 && warp_n == 16 && warp_k == 16 &&
std::is_same_v<ADataType, fp16_t>) ||
(warp_m == 32 && warp_n == 32 && warp_k == 8 &&
std::is_same_v<ADataType, fp16_t>))
{
return {2, 1};
}
// K1 per Mfma = 2 cases: mfma_per_wg = 1, dsread_per_wg = 2
else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 128 &&
std::is_same_v<ADataType, fp8_t>) ||
(warp_m == 32 && warp_n == 32 && warp_k == 64 &&
std::is_same_v<ADataType, fp8_t>))
{
return {1, 2};
}
// K1 per Mfma = 1 cases: mfma_per_wg = 1, dsread_per_wg = 1
else if constexpr((warp_m == 16 && warp_n == 16 && warp_k == 32 &&
std::is_same_v<ADataType, fp16_t>) ||
(warp_m == 32 && warp_n == 32 && warp_k == 16 &&
std::is_same_v<ADataType, fp16_t>) ||
(warp_m == 16 && warp_n == 16 && warp_k == 128 /*&&
std::is_same_v<ADataType, fp4_t> */) ||
(warp_m == 32 && warp_n == 32 && warp_k == 64 /*&&
std::is_same_v<ADataType, fp4_t> */))
{
return {1, 1};
}
// Default configuration
else
{
return {1, 1};
}
}
static constexpr auto mfma_config = GetMfmaConfig();
static constexpr auto mfma_per_wg = mfma_config.mfma_per_wg;
static constexpr auto dsread_per_wg = mfma_config.dsread_per_wg;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AGmemBGmemCRegV1",
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));
// clang-format on
}
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
static constexpr bool DoubleSmemBuffer = false;
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return PipelinePolicy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
{
// Keypoint of pipeline optimize is workload balance in time
// instruction schedule example(128X256X256, 1X4, 16X16X128):
// Iter MNK MFMA ds_read ds_write A_load b_load
// -1 M6N0: 57 - 8 - -
// -1 M6N1: 58 1 - - -
// -1 M6N2: 59 - - 7 -
// -1 M6N3: 60 2 - - -
// -1 M7N0: 61 - - - -
// -1 M7N1: 62 3 - - -
// -1 M7N2: 63 - - 8 -
// -1 M7N3: 64 4 - - -
// 0 M0N0K0: 1 - - - -
// 0 M0N1: 2 5 - - 2
// 0 M0N2: 3 - - - -
// 0 M0N3: 4 6 - - -
// 0 M1N0: 5 - - - -
// 0 M1N1: 6 7 - - 4
// 0 M1N2: 7 - - - -
// 0 M1N3: 8 8 - - -
// 0 M2N0: 9 - - - -
// 0 M2N1: 10 9 - - 6
// 0 M2N2: 11 - - - -
// 0 M2N3: 12 10 - - -
// 0 M3N0: 13 - 1 - -
// 0 M3N1: 14 11 - - 8
// 0 M3N2: 15 - - - -
// 0 M3N3: 16 12 - - -
// 0 M4N0: 17 - 2 - -
// 0 M4N1: 18 13 - - -
// 0 M4N2: 19 - - 1 -
// 0 M4N3: 20 14 - - -
// 0 M5N0: 21 - 3 - -
// 0 M5N1: 22 15 - - -
// 0 M5N2: 23 - - 2 -
// 0 M5N3: 24 16 - - -
// 0 M6N0: 25 - 4 - -
// 0 M6N1: 26 17 - - -
// 0 M6N2: 27 - - 3 -
// 0 M6N3: 28 18 - - -
// 0 M7N0: 29 - - - -
// 0 M7N1: 30 19 - - -
// 0 M7N2: 31 - - 4 -
// 0 M7N3: 32 20 - - -
// 0 M0N0K1: 33 - - - -
// 0 M0N1: 34 21 - - 10
// 0 M0N2: 35 - - - -
// 0 M0N3: 36 22 - - -
// 0 M1N0: 37 - - - -
// 0 M1N1: 38 23 - - 12
// 0 M1N2: 39 - - - -
// 0 M1N3: 40 24 - - -
// 0 M2N0: 41 - - - -
// 0 M2N1: 42 25 - - 14
// 0 M2N2: 43 - - - -
// 0 M2N3: 44 26 - - -
// 0 M3N0: 45 - 5 - -
// 0 M3N1: 46 27 - - 16
// 0 M3N2: 47 - - - -
// 0 M3N3: 48 28 - - -
// 0 M4N0: 49 - 6 - -
// 0 M4N1: 50 29 - - -
// 0 M4N2: 51 - - 5 -
// 0 M4N3: 52 30 - - -
// 0 M5N0: 53 - 7 - -
// 0 M5N1: 54 31 - - -
// 0 M5N2: 55 - - 6 -
// 0 M5N3: 56 32 - - -
// 0 M6N0: 57 - 8 - -
// 0 M6N1: 58 1 - - -
// 0 M6N2: 59 - - 7 -
// 0 M6N3: 60 2 - - -
// 0 M7N0: 61 - - - -
// 0 M7N1: 62 3 - - -
// 0 M7N2: 63 - - 8 -
// 0 M7N3: 64 4 - - -
#if 0
constexpr auto dsread_num_perK = dsread_per_wg * MIterPerWarp;
constexpr auto dswrite_num_perK = (dsread_num_perK + MWarp * NWarp - 1) / (MWarp * NWarp);
constexpr auto dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
// index_t dsread_perM[MIterPerWarp];
// index_t dswrite_perM[MIterPerWarp];
index_t dsread_perM[MIterPerWarp];
index_t dswrite_perM[MIterPerWarp];
index_t load_perM[MIterPerWarp];
constexpr int dswrite_inst = dswrite_num_perK;
constexpr int NIter_num = NIterPerWarp*mfma_per_wg;
#pragma unroll
for(int i=0;i<MIterPerWarp;i++)
{
dsread_perM[i] = 2;
if(i==0)
{
dswrite_perM[0] = (dswrite_inst - MIterPerWarp + 2) > 0 ? dswrite_inst - MIterPerWarp + 2 : 0;
}
else if(i==MIterPerWarp-1)
{
dswrite_perM[MIterPerWarp-1] = 0;
}
else
{
dswrite_perM[i] = (i + 2 - dswrite_inst) > 0 ? 1 : 0;
}
}
#pragma unroll
for(int i=0;i<4;i++)
{
load_perM[i] = 2;
}
#pragma unroll
for(int i=4;i<8;i++)
{
load_perM[i] = 1;
}
#pragma unroll
for(int i=0;i<MIterPerWarp;i++)
{
int biger_num = dsread_perM[i] > load_perM[i] ? (dsread_perM[i] > dswrite_perM[i] ? dsread_perM[i] : dswrite_perM[i]) : (load_perM[i] > dswrite_perM[i] ? load_perM[i] : dswrite_perM[i]);
int total_num = dsread_perM[i] + load_perM[i] + dswrite_perM[i];
int gap = (total_num+NIter_num-1)/NIter_num;
index_t inst_order[MIterPerWarp*10];
#pragma unroll
for(int j=0;j<MIterPerWarp*10;j++)
{
inst_order[j] = 0;
}
int index=0;
#pragma unroll
for(int j=0;j<biger_num;j++)
{
if(dswrite_perM[i]>j)
{
inst_order[index] = 1;
index++;
}
if(load_perM[i]>j)
{
inst_order[index] = 2;
index++;
}
if(dsread_perM[i]>j)
{
inst_order[index] = 3;
index++;
}
}
#pragma unroll
for(int j=0;j<NIter_num;j++)
{
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
#pragma unroll
for(int m=0;m<gap;m++)
{
if(m%2==0)
{
if(inst_order[j+m*NIter_num]==1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[j+m*NIter_num]==2)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if(inst_order[j+m*NIter_num]==3)
{
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
}
else
{
if(inst_order[(m+1)*NIter_num-1-j]==1)
{
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
}
if(inst_order[(m+1)*NIter_num-1-j]==2)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if(inst_order[(m+1)*NIter_num-1-j]==3)
{
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
}
}
}
}
__builtin_amdgcn_sched_barrier(0);
#endif
if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128)
{
constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad;
constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp;
constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp;
static_for<0, A_LDS_Read_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
});
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA
});
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA
});
__builtin_amdgcn_sched_barrier(0);
}
}
CK_TILE_HOST_DEVICE static constexpr auto TailHotLoopScheduler()
{
#if 0
static_for<0, 2, 1>{}([&](auto j) {
ignore = j;
static_for<0, 3, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
static_for<0, 3, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
#endif
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
"wrong!");
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
const index_t iMWarp = get_warp_id() / NWarp;
using CWarpDstr = typename WG::CWarpDstr;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
__builtin_amdgcn_sched_barrier(0);
// A tile in LDS
ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
constexpr auto a_lds_block_desc =
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block_ping = make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
auto a_lds_block_pong = make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
// A DRAM tile window for load
auto a_copy_dram_window =
make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
a_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto a_copy_lds_window_ping =
make_tile_window(a_lds_block_ping,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
auto a_copy_lds_window_pong =
make_tile_window(a_lds_block_pong,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
PipelinePolicy::template MakeADramTileDistribution<Problem>());
// ping-pong window for A LDS
auto a_warp_window_ping_tmp = make_tile_window(
a_lds_block_ping,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
auto a_warp_window_pong_tmp = make_tile_window(
a_lds_block_pong,
make_tuple(number<WG::kM>{}, number<WG::kK>{}),
{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows_ping;
statically_indexed_array<
statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
MIterPerWarp>
a_warp_windows_pong;
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
move_tile_window(a_warp_windows_ping(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
move_tile_window(a_warp_windows_pong(mIter)(kIter),
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
});
});
// Block GEMM
auto block_flatmm = BlockFlatmm();
// Acc register tile
auto c_block_tile = block_flatmm.MakeCBlockTile();
// B flat DRAM window for load
auto b_flat_distribution =
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
auto b_flat_dram_window = // tile_window_with_static_distribution
make_tile_window(
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
b_flat_dram_block_window_tmp.get_window_origin(),
b_flat_distribution);
// pingpong buffer for B
statically_indexed_array<
statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
NIterPerWarp>
b_flat_dram_windows;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_ping;
statically_indexed_array<
statically_indexed_array<decltype(load_tile(b_flat_dram_window)), KIterPerWarp>,
NIterPerWarp>
b_warp_tensor_pong;
// Prefetch A0
auto a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// prefetch B
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
__builtin_amdgcn_sched_barrier(0);
// Prefetch A1
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
index_t iCounter = (num_loop - 1) / 2;
while(iCounter > 0)
{
// prefetch B(2i+1)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// Prefill A(2i+1)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
// Prefetch A(2i+2)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
auto a_warp_tensor_ping = load_tile(a_warp_windows_ping(mIter)(kIter));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_ping, b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
block_sync_lds();
HotLoopScheduler();
//Next K
// prefetch B(2i+2)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// Prefill A(2i+2)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
// Prefetch A(2i+3)
a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
// GEMM 2i+1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
auto a_warp_tensor_pong = load_tile(a_warp_windows_pong(mIter)(kIter));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_pong, b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
// move B window to next flat K
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
block_sync_lds();
HotLoopScheduler();
iCounter--;
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
// prefetch B(loopK)
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
// Prefill A(loopK)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
// GEMM loopK-1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
auto a_warp_tensor_ping = load_tile(a_warp_windows_ping(mIter)(kIter));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_ping, b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
block_sync_lds();
TailHotLoopScheduler();
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
auto a_warp_tensor_pong = load_tile(a_warp_windows_pong(mIter)(kIter));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_pong, b_warp_tensor_pong(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
// GEMM loopK
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
auto a_warp_tensor_ping = load_tile(a_warp_windows_ping(mIter)(kIter));
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor_ping, b_warp_tensor_ping(nIter)(kIter));
// write C warp tensor into C block tensor
c_block_tile.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
return c_block_tile;
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
void* p_smem_ping,
void* p_smem_pong) const
{
return operator()(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_flat_dram_block_window_tmp,
num_loop,
p_smem_ping,
p_smem_pong);
}
};
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
namespace ck_tile {
@@ -238,22 +239,48 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
{
using TileShape = typename Problem::BlockGemmShape;
#if defined(__gfx11__)
constexpr index_t scale = 4;
#else
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
#endif
if constexpr(TileShape::WarpTile::at(I1) == 32)
{
return TileShape::WarpTile::at(I2) * scale / 2;
return TileShape::WarpTile::at(I2) / 2;
}
else
{
static_assert(TileShape::WarpTile::at(I1) == 16);
return TileShape::WarpTile::at(I2) * scale / 4;
return TileShape::WarpTile::at(I2) / 4;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALDS_WarpTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t KPerXdl = Problem::BlockGemmShape::WarpTile::at(I2);
constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
constexpr int KLane = get_warp_size() / MPerXdl;
constexpr int KPerThread = KPerXdl / KLane;
constexpr int MaxVecSize = 16 / sizeof(ADataType);
constexpr int KItemsPerLoad = min(MaxVecSize, KPerThread);
constexpr int KFragment = KPerThread / KItemsPerLoad;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<Repeat>,
tuple<sequence<MPerXdl>, sequence<KFragment, KLane, KItemsPerLoad>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 2>,
sequence<0, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
@@ -307,10 +334,10 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
// coalesce reading for each blocks
if constexpr(get_warp_size() % (M2 * K0) == 0)
if constexpr(get_warp_size() % K0 == 0)
{
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
@@ -329,24 +356,55 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
constexpr index_t KWave = K0 / get_warp_size();
constexpr index_t M0 = BlockSize / get_warp_size() / KWave;
constexpr index_t M1 = MPerBlock / M0;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M0, M1>, sequence<KWave, get_warp_size(), K1>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<0, 0>, sequence<1>>,
sequence<1, 2>,
sequence<1, 2>>{});
}
}
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeADramDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
// constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
// constexpr index_t M0 = MPerBlock / (M2 * M1);
// static_assert(M0 * M1 * M2 == MPerBlock,
// "Incorrect M0, M2, M1 configuration! "
// "M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2>,
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
@@ -355,15 +413,16 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KBPerLoad = GetKBPerLoad<Problem>();
#if defined(__gfx11__)
constexpr index_t KRepeatInWave = 2;
#else
constexpr index_t KRepeatInWave = 1;
#endif
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
constexpr index_t MaxVecSize = 16 / sizeof(typename Problem::BDataType);
constexpr index_t KItemsPerLoad = min(KBPerLoad, MaxVecSize);
constexpr index_t KFragment = KBPerLoad / KItemsPerLoad;
static_assert(KFragment * KItemsPerLoad == KBPerLoad);
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim./
constexpr index_t KWavePerBlk = 1;
constexpr index_t KRepeat = 1;
static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong");
static_assert(TileShape::BlockWarps::at(number<2>{}) == 1, "Requires K_Warp == 1");
constexpr index_t NBPerLoad = 1;
constexpr index_t NThdPerWave = 1;
@@ -371,15 +430,17 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
constexpr index_t NRepeat = 1;
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat, KRepeatInWave>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KRepeat, KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
sequence<WaveRepeat>, // ?
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
sequence<KFragment, KWavePerBlk, KThdPerWave, KItemsPerLoad>>, // first
// direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<1, 2, 2>>, // which index
tuple<sequence<0, 1, 2>, sequence<1, 2>>, // which direction
tuple<sequence<0, 1, 1>, sequence<2, 2>>, // which index
// <repeat, vec_load>
sequence<1, 1, 2, 2>,
sequence<0, 3, 0, 3>>{});
@@ -440,12 +501,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
typename Problem::BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
typename Problem::ADataType,

View File

@@ -0,0 +1,240 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp"
namespace ck_tile {
#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE 0
#if defined(__gfx950__)
#define CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4 1
#else
#define CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4 0
#endif
#define CKTILE_FLATMM_USE_BUFFER_LOAD_LDS \
(CKTILE_FLATMM_USE_BUFFER_LOAD_LDS_AS_POSSIBLE && \
CKTILE_FLATMM_ARCH_SUPPORT_BUFFER_LOAD_LDS_DWORDx4)
struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr index_t KBPerLoad = 32;
static constexpr index_t N_Pack = 2; // it's fixed for fp4
static constexpr index_t K_Pack = 2; // it's fixed for fp4
template <typename Problem, typename NativeADramTensorView>
CK_TILE_HOST_DEVICE static constexpr auto
TransformF16xF4_ATensorView(const NativeADramTensorView& a_dram_view)
{
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
constexpr int DynamicTileOffsetFlag = 0;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
// implement swizzle pattern on global side
// because we can't adjust the ds_write pattern of BUFFER_LOAD_LDS.
auto swizzle_a_dram_view_1 = transform_tensor_view(
a_dram_view,
make_tuple(
// M-dim is not affected by swizzle pattern
make_unmerge_transform(
make_tuple(number<DynamicTileOffsetFlag>{}, number<MPerBlock>{})),
// K-dim is the swizzle dimension
make_unmerge_transform(make_tuple(number<DynamicTileOffsetFlag>{},
number<KPerBlock / KPack>{},
number<KPack>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}));
auto swizzle_a_dram_view_2 = transform_tensor_view(
swizzle_a_dram_view_1,
make_tuple(make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
make_xor_transform(make_tuple(number<MPerBlock>{},
number<ContiguousThreadsCntInDS_READ_16B>{})),
make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
return transform_tensor_view(
swizzle_a_dram_view_2,
make_tuple(
make_merge_transform_v3_division_mod(
make_tuple(number<DynamicTileOffsetFlag>{}, number<MPerBlock>{})),
make_merge_transform_v3_division_mod(make_tuple(number<DynamicTileOffsetFlag>{},
number<KPerBlock / KPack>{},
number<KPack>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#else
return a_dram_view;
#endif
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ReadALdsBlockDescriptor()
{
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
/*reduce transform layers,compare with old ck*/
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
number<ContiguousThreadsCntInDS_READ_16B>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_WriteALdsBlockDescriptor()
{
#if CKTILE_FLATMM_USE_BUFFER_LOAD_LDS
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
return make_naive_tensor_descriptor(make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
make_tuple(number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
#else
return MakeF16xF4_ReadALdsBlockDescriptor<Problem>();
#endif
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF16xF4_ALDS_TileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
constexpr int Repeat = TileShape::BlockWarps::at(number<1>{});
constexpr int M0 = TileShape::WarpTile::at(I0);
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 8
constexpr int XDL_PerThreadK = KBPerLoad / K2; // 4
constexpr int K0 = K_Lane; // 4
return make_static_tile_distribution(
tile_distribution_encoding<sequence<Repeat>,
tuple<sequence<M0>, sequence<K0, XDL_PerThreadK, K2>>,
tuple<sequence<0>, sequence<2, 1>>,
tuple<sequence<0>, sequence<0, 0>>,
sequence<2>,
sequence<2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4BFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<WaveRepeat>, // ?
tuple<sequence<NWavePerBlk, N_Pack>, // second
// direction
sequence<KWavePerBlk, KThdPerWave, KBPerLoad>>, // first direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<0, 1, 2>, sequence<2>>, // which direction
tuple<sequence<0, 0, 0>, sequence<1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeFp4ScaleBFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
constexpr index_t XDLPerBlock = TileShape::kK / TileShape::WarpTile::at(I2);
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
constexpr index_t NWavePerBlk = N_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<>, // ?
tuple<sequence<NWavePerBlk>, // second direction
sequence<K_Lane, N_Lane, N_Pack * K_Pack>>, // first
// direction
// wave in blk, // thd in wave
// <M, K> // <M, K>
tuple<sequence<1>, sequence<2, 2>>, // which direction
tuple<sequence<0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<2>>{});
}
};
} // namespace ck_tile

View File

@@ -49,9 +49,6 @@ struct GemmPipelineProblemBase
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
// In the base situation, the Preshuffle setting should be false.
static constexpr bool Preshuffle = false;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
@@ -181,6 +178,150 @@ using GemmPipelineProblem = GemmPipelineProblemBase<ADataType_,
VectorSizeA_,
VectorSizeB_>;
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
struct FlatmmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;
using BLayout = remove_cvref_t<typename Traits::BLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr bool TransposeC = Traits::TransposeC;
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_problem",
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler);
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
? pixels_per_thread
: PackedSize * VectorLoadSize / sizeof(ADataType);
}
else
{
return VectorLoadSize / sizeof(ADataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
? pixels_per_thread
: PackedSize * VectorLoadSize / sizeof(BDataType);
}
else
{
return PackedSize * VectorLoadSize / sizeof(BDataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
{
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
constexpr index_t M0 = get_warp_size() / N2;
constexpr index_t M1 = BlockGemmShape::kM / M0;
return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
else
{
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = BlockGemmShape::kN / N0;
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
}
static constexpr index_t VectorSizeA = []() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return kPadK ? 1 : GetAlignmentA();
}
else
{
return kPadM ? 1 : GetAlignmentA();
}
}();
static constexpr index_t VectorSizeB = []() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return kPadN ? 1 : GetAlignmentB();
}
else
{
return kPadK ? 1 : GetAlignmentB();
}
}();
static constexpr index_t VectorSizeC = []() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return kPadN ? 1 : GetAlignmentC();
}
else
{
return kPadM ? 1 : GetAlignmentC();
}
}();
};
template <typename ADataType_,
typename BDataType_,
typename CDataType_,

View File

@@ -0,0 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/moe_flatmm/kernel/moe_flatmm_kernel.hpp"
#include "ck_tile/ops/moe_flatmm/pipeline/moe_flatmm_pipeline_agmem_bgmem_creg.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff