mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] MX Flatmm Split kernel instances (#3207)
* [CK_TILE] MX Flatmm Split kernel instances * Fix flatmm example compile
This commit is contained in:
@@ -14,7 +14,12 @@ if(has_supported_gpu)
|
||||
add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
|
||||
add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
|
||||
add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp)
|
||||
add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp) # TODO: 950 only
|
||||
|
||||
include(mxgemm/mx_flatmm_instance.cmake)
|
||||
mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES)
|
||||
message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}")
|
||||
add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp ${EXAMPLE_MX_FLATMM_FILES})
|
||||
target_include_directories(tile_example_mx_flatmm PRIVATE mxgemm)
|
||||
|
||||
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
|
||||
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
|
||||
|
||||
@@ -29,7 +29,7 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
auto flatmm_shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
|
||||
@@ -20,211 +20,6 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise>
|
||||
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
persistent,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>;
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
|
||||
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::MXFlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using CodegenMXFlatmmPipeline =
|
||||
ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using Kernel =
|
||||
ck_tile::MXFlatmmKernel<TilePartitioner, CodegenMXFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenMXFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -269,21 +64,66 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
scale_a,
|
||||
scale_b};
|
||||
|
||||
float ave_time = mx_flatmm_calc<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ScaleA,
|
||||
ScaleB,
|
||||
UsePersistentKernel,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
using FlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, FlatmmShape, Traits>;
|
||||
|
||||
using BaseFlatmmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split);
|
||||
const bool has_hot_loop = BaseFlatmmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseFlatmmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time = BaseFlatmmPipeline::template TailHandler<true>(
|
||||
[&](auto has_hot_loop_, auto tail_num_) {
|
||||
constexpr auto has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_num_v = tail_num_.value;
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
return mx_flatmm_calc<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ScaleA,
|
||||
ScaleB,
|
||||
UsePersistentKernel,
|
||||
CDEElementWise,
|
||||
split_k_.value,
|
||||
has_hot_loop_v,
|
||||
tail_num_v>(
|
||||
args,
|
||||
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
};
|
||||
return (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
|
||||
: invoke_splitk_path(std::true_type{});
|
||||
},
|
||||
has_hot_loop,
|
||||
tail_num);
|
||||
|
||||
constexpr int APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr int BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
@@ -297,8 +137,8 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run MXFP4_Flatmm kernel " //
|
||||
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
|
||||
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
|
||||
<< " M = " << M << " N = " << N << " K = " << K << " StrideA = " << stride_A
|
||||
<< " StrideB = " << stride_B << " StrideC = " << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
@@ -441,21 +281,13 @@ int run_mx_flatmm_example(int argc, char* argv[])
|
||||
if(mx_prec == "fp4xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
FlatmmConfig,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
FlatmmConfig,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
else
|
||||
{
|
||||
run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
FlatmmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
throw std::runtime_error("Only non-persistent kernels are supported currently!");
|
||||
}
|
||||
else if(mx_prec == "fp6xfp6")
|
||||
{
|
||||
@@ -487,7 +319,7 @@ int main(int argc, char* argv[])
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_mx_flatmm_example<MXfp4_FlatmmConfig16>(argc, argv);
|
||||
return run_mx_flatmm_example<MXfp4_FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
|
||||
27
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake
Normal file
27
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake
Normal file
@@ -0,0 +1,27 @@
|
||||
function(mx_flatmm_instance_generate FILE_LIST)
|
||||
set(FLATMM_CONFIG MXfp4_FlatmmConfig16)
|
||||
set(A_DATA_TYPE FP4)
|
||||
set(B_DATA_TYPE FP4)
|
||||
set(C_DATA_TYPE FP16)
|
||||
set(A_LAYOUT ROW)
|
||||
set(B_LAYOUT COL)
|
||||
set(C_LAYOUT ROW)
|
||||
|
||||
# foreach(PERSISTENT false true)
|
||||
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
|
||||
foreach(PERSISTENT false)
|
||||
foreach(SPLIT_K false true)
|
||||
foreach(HAS_HOT_LOOP false true)
|
||||
foreach(TAIL_NUMBER ODD EVEN)
|
||||
set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
|
||||
@ONLY)
|
||||
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
set(${FILE_LIST} ${${FILE_LIST}} PARENT_SCOPE)
|
||||
endfunction()
|
||||
53
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in
Normal file
53
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in
Normal file
@@ -0,0 +1,53 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "mx_flatmm_instance.hpp"
|
||||
|
||||
// clang-format off
|
||||
#define FLATMM_CONFIG @FLATMM_CONFIG@
|
||||
#define A_DATA_TYPE @A_DATA_TYPE@
|
||||
#define B_DATA_TYPE @B_DATA_TYPE@
|
||||
#define C_DATA_TYPE @C_DATA_TYPE@
|
||||
#define A_LAYOUT @A_LAYOUT@
|
||||
#define B_LAYOUT @B_LAYOUT@
|
||||
#define C_LAYOUT @C_LAYOUT@
|
||||
#define PERSISTENT @PERSISTENT@
|
||||
#define SPLIT_K @SPLIT_K@
|
||||
#define HAS_HOT_LOOP @HAS_HOT_LOOP@
|
||||
#define TAIL_NUMBER @TAIL_NUMBER@
|
||||
// clang-format on
|
||||
|
||||
using FP4 = ck_tile::pk_fp4_t;
|
||||
using FP16 = ck_tile::fp16_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
|
||||
using ROW = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using COL = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
inline constexpr auto ODD = ck_tile::TailNumber::Odd;
|
||||
inline constexpr auto EVEN = ck_tile::TailNumber::Even;
|
||||
|
||||
inline constexpr int ScaleGranularityM = 1;
|
||||
inline constexpr int ScaleGranularityN = 1;
|
||||
inline constexpr int ScaleGranularityK = 32;
|
||||
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>;
|
||||
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>;
|
||||
|
||||
template float mx_flatmm_calc<FLATMM_CONFIG,
|
||||
A_DATA_TYPE,
|
||||
B_DATA_TYPE,
|
||||
/*DsDatatype*/ ck_tile::tuple<>,
|
||||
/*AccDataType*/ float,
|
||||
C_DATA_TYPE,
|
||||
A_LAYOUT,
|
||||
B_LAYOUT,
|
||||
/*DsLayout*/ ck_tile::tuple<>,
|
||||
C_LAYOUT,
|
||||
ScaleM,
|
||||
ScaleN,
|
||||
PERSISTENT,
|
||||
/*CDEElementWise*/ ck_tile::element_wise::PassThrough,
|
||||
SPLIT_K,
|
||||
HAS_HOT_LOOP,
|
||||
TAIL_NUMBER>(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s);
|
||||
172
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp
Normal file
172
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp
Normal file
@@ -0,0 +1,172 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mx_flatmm.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
using is_row_major_t = ck_tile::bool_constant<
|
||||
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise,
|
||||
bool Splitk,
|
||||
bool HasHotLoop,
|
||||
ck_tile::TailNumber TailNum>
|
||||
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using FlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using MXGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
persistent,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>;
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
|
||||
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
|
||||
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation =
|
||||
Splitk ? ck_tile::memory_operation_enum::atomic_add : ck_tile::memory_operation_enum::set;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using MXPipelineProblem = ck_tile::MXFlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
FlatmmShape,
|
||||
MXGemmTraits,
|
||||
scheduler,
|
||||
HasHotLoop,
|
||||
TailNum>;
|
||||
|
||||
using MXFlatmmPipeline = ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
using GemmEpilogue =
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using Kernel = ck_tile::MXFlatmmKernel<TilePartitioner, MXFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << FlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << FlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << MXPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << MXFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major_t<ALayout>{}));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major_t<BLayout>{}));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
@@ -38,3 +38,23 @@ struct MXfp4_FlatmmConfig16
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise,
|
||||
bool Splitk,
|
||||
bool HasHotLoop,
|
||||
ck_tile::TailNumber TailNum>
|
||||
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s);
|
||||
|
||||
@@ -163,5 +163,5 @@ int run_mx_flatmm_with_layouts(int argc,
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
return pass ? 0 : -1;
|
||||
}
|
||||
|
||||
@@ -3,23 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
// mfma_type, 0:32x32, 1:16x16
|
||||
template <typename FlatmmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / FlatmmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
FlatmmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
|
||||
@@ -114,7 +114,7 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
|
||||
|
||||
auto b_shuffle_host = shuffle_b<FlatmmConfig>(b_k_n_tensor);
|
||||
auto b_shuffle_host = flatmm_shuffle_b<FlatmmConfig>(b_k_n_tensor);
|
||||
|
||||
std::cout << "moe_flatmm:" //
|
||||
<< "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
|
||||
|
||||
Reference in New Issue
Block a user