mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
This commit is contained in:
49
example/ck_tile/18_flatmm/CMakeLists.txt
Normal file
49
example/ck_tile/18_flatmm/CMakeLists.txt
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# GPU architectures the flatmm examples can be built for.
set(SUPPORTED_GPUS gfx908 gfx90a gfx942 gfx950)

# Build the examples only when at least one requested target architecture
# is in the supported list.
set(has_supported_gpu FALSE)
foreach(gpu IN LISTS GPU_TARGETS)
    if(gpu IN_LIST SUPPORTED_GPUS)
        set(has_supported_gpu TRUE)
        break()
    endif()
endforeach()

if(has_supported_gpu)
    # NOTE: we turn off undefined-func-template to let the sources compile without
    # explicitly declared function specializations, because they are auto-generated.
    set(EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template)

    if(CK_USE_OCP_FP8)
        list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
    endif()
    # SHELL: keeps "-mllvm -greedy-reverse-local-assignment=1" together as one
    # shell-level argument group instead of de-duplicating the pieces.
    list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")

    add_executable(tile_example_flatmm_basic flatmm_basic.cpp)
    target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

    add_executable(tile_example_moe_flatmm moe_flatmm.cpp)
    target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

    add_executable(tile_example_grouped_flatmm grouped_flatmm.cpp)
    target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

    # Mixed-precision examples are built only for gfx94x / gfx95x targets.
    if(GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx94")
        add_executable(tile_example_mixed_prec_flatmm mixed_prec/mixed_prec_flatmm.cpp)
        target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

        add_executable(tile_example_a16w4_moe_flatmm mixed_prec/a16w4_moe_flatmm.cpp)
        target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
    endif()

    # Microscaling (MX) flatmm is built only for gfx95x targets; the kernel
    # instance sources are generated at configure time.
    if(GPU_TARGETS MATCHES "gfx95")
        include(mxgemm/mx_flatmm_instance.cmake)
        mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES)
        message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}")

        add_executable(tile_example_mx_flatmm mxgemm/mx_flatmm.cpp ${EXAMPLE_MX_FLATMM_FILES})
        target_include_directories(tile_example_mx_flatmm PRIVATE mxgemm)
        target_compile_options(tile_example_mx_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
    endif()
endif()
|
||||
79
example/ck_tile/18_flatmm/README.md
Normal file
79
example/ck_tile/18_flatmm/README.md
Normal file
@@ -0,0 +1,79 @@
|
||||
# FLATMM Matrix Multiplication with CK Tile
|
||||
|
||||
This example demonstrates FLATMM (flattened matrix multiplication) using the CK Tile programming model. FLATMM is a variant of GEMM optimized for certain memory layouts and batch processing patterns. Currently, this example only covers the basic CK Tile FLATMM functionality, but it establishes placeholders for future support of additional FLATMM pipelines and modules. Over time, all FLATMM features from the original CK will be gradually migrated to CK Tile.
|
||||
|
||||
---
|
||||
|
||||
## Algorithm and Math
|
||||
|
||||
Given:
|
||||
- $A$: $[\text{batch}, M, K]$
|
||||
- $B$: $[\text{batch}, K, N]$
|
||||
- $C$: $[\text{batch}, M, N]$
|
||||
|
||||
For each batch $b$:
|
||||
$$
|
||||
C^{(b)} = A^{(b)} \times B^{(b)}
|
||||
$$
|
||||
|
||||
- **FLATMM**: An alternative to the pre-shuffled GEMM found in [03_gemm](../03_gemm/README.md)
|
||||
|
||||
|
||||
---
|
||||
|
||||
## Build & Run
|
||||
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The basic pipeline method on the flatmm calculation
|
||||
make tile_example_flatmm_basic -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_flatmm_basic`
|
||||
|
||||
### Arguments
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:256)
|
||||
-n n dimension (default:256)
|
||||
-k k dimension (default:128)
|
||||
-a_layout A tensor data layout - Row by default (default:R)
|
||||
-b_layout B tensor data layout - Column by default (default:C)
|
||||
-c_layout C tensor data layout - Row by default (default:R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp8)
|
||||
-warmup number of iterations before benchmark the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-split_k splitK value (default:1)
|
||||
-init 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-warp_tile 0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only) (default:0)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:flatmm_basic.json)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Source Structure
|
||||
|
||||
- **Kernel**: [`flatmm_basic.hpp`](flatmm_basic.hpp) (tile-programming kernel template)
|
||||
- **Executable**: [`flatmm_basic.cpp`](flatmm_basic.cpp)
|
||||
- **Build**: `CMakeLists.txt`, `run_flatmm_example.inc`, `script/`
|
||||
|
||||
---
|
||||
|
||||
## Related CK Tile Examples
|
||||
|
||||
- [16_batched_gemm](../16_batched_gemm/README.md): Batched GEMM with tiles
|
||||
- [03_gemm](../03_gemm/README.md): Single GEMM with tiles
|
||||
- [17_grouped_gemm](../17_grouped_gemm/README.md): Grouped GEMM with tiles
|
||||
|
||||
For distribution, see [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
|
||||
|
||||
---
|
||||
[Back to CK Tile Examples](../README.md)
|
||||
511
example/ck_tile/18_flatmm/flatmm_basic.cpp
Normal file
511
example/ck_tile/18_flatmm/flatmm_basic.cpp
Normal file
@@ -0,0 +1,511 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <numeric>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "flatmm_basic.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename T>
|
||||
constexpr const char* DataTypeToString()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck_tile::half_t>)
|
||||
{
|
||||
return "fp16";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::fp8_t>)
|
||||
{
|
||||
return "fp8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::bf8_t>)
|
||||
{
|
||||
return "bf8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
return "bf16";
|
||||
}
|
||||
else
|
||||
{
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
// Returns a compile-time bool_constant that is true iff the given layout tag
// is RowMajor. The result is passed as a tag into host tensor-descriptor
// construction (see the flush-cache path in flatmm_calc below).
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
}
||||
|
||||
// mfma_type, 0:32x32, 1:16x16
// Host-side pre-shuffle of the B matrix into the "flat" layout consumed by the
// flatmm kernel. The 2-D tensor t (lengths [K, N]) is reinterpreted as a 4-D
// view [N / N_Warp_Tile, N_Warp_Tile, K / ItemsPerAccess, ItemsPerAccess] and
// the two middle axes are swapped via reference_permute.
// NOTE(review): the permute order {0, 2, 1, 3} is assumed to match the
// kernel's expected flat-B access pattern — confirm against the device code.
template <typename FlatmmConfig, typename T>
auto shuffle_b_v0(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[0];

    // MaxVecSize: elements per 16-byte access; KLane: warp lanes left over
    // after covering N_Warp_Tile; ItemsPerAccess is capped by both.
    constexpr int MaxVecSize = 16 / sizeof(T);
    constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile;
    constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane);

    ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
                                   FlatmmConfig::N_Warp_Tile,
                                   k_ / ItemsPerAccess,
                                   ItemsPerAccess});
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
}
||||
|
||||
// Alternative host-side B pre-shuffle: like shuffle_b_v0 but additionally
// splits N into (N_Tile blocks x N_Warp x N_Warp_Tile x NRepeat) so warps and
// per-warp repeats get their own axes before permutation.
// NOTE(review): the permute order {0, 3, 1, 4, 2, 5} is assumed to match the
// kernel's flat-B layout for this tiling — confirm against the device code.
template <typename FlatmmConfig, typename T>
auto shuffle_b_v1(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[0];

    // Same per-access sizing as shuffle_b_v0, plus the per-warp repeat count
    // along N.
    constexpr int MaxVecSize = 16 / sizeof(T);
    constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile;
    constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane);
    constexpr int NRepeat = FlatmmConfig::N_Tile / FlatmmConfig::N_Warp_Tile / FlatmmConfig::N_Warp;

    ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Tile,
                                   FlatmmConfig::N_Warp,
                                   FlatmmConfig::N_Warp_Tile,
                                   NRepeat,
                                   k_ / ItemsPerAccess,
                                   ItemsPerAccess});
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5});
}
|
||||
|
||||
// Computes (rtol, atol) validation thresholds for comparing GPU results
// against the host reference. Two error sources are combined:
//  1) accumulation over ceil(K / kbatch) products in the narrower of the two
//     input types, and
//  2) the extra rounding introduced by split-K accumulation in CDataType.
// The per-component maxima are returned so validation uses the looser bound.
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // The narrower input type dominates the rounding error of each product.
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
    // Calculate thresholds
    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
        ck_tile::integer_divide_ceil(K, kbatch));
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
    // Calculate error due to split_k accumulation
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);
    // Use higher threshold
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
|
||||
|
||||
// Instantiates and launches the flatmm kernel described by FlatmmConfig for
// the given host args, returning the averaged kernel time in milliseconds.
//
// The hot-loop count and tail handling depend on the runtime K / k_batch, so
// the kernel type cannot be fixed at compile time; instead the Run lambda is
// invoked by BaseGemmPipeline::TailHandler with compile-time (has_hot_loop,
// tail_number) constants and builds + launches the matching kernel instance.
template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ScaleM,
          typename ScaleN,
          bool persistent,
          typename CDEElementWise>
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
                  const ck_tile::stream_config& s)
{
    // Block tile / warp layout / warp tile shape, all taken from FlatmmConfig.
    using CodegenFlatmmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
        ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
        ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
                          FlatmmConfig::N_Warp_Tile,
                          FlatmmConfig::K_Warp_Tile>>;

    using TilePartitioner =
        ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
                                                   FlatmmConfig::TileParitionerGroupNum,
                                                   FlatmmConfig::TileParitionerM01>;

    using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
                                           FlatmmConfig::kPadN,
                                           FlatmmConfig::kPadK,
                                           ALayout,
                                           BLayout,
                                           ELayout,
                                           FlatmmConfig::NumWaveGroups>;

    using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
                                                               FlatmmConfig::kPadN,
                                                               FlatmmConfig::kPadK,
                                                               FlatmmConfig::DoubleSmemBuffer,
                                                               ALayout,
                                                               BLayout,
                                                               ELayout,
                                                               FlatmmConfig::TransposeC,
                                                               FlatmmConfig::UseStructuredSparsity,
                                                               persistent,
                                                               FlatmmConfig::NumWaveGroups,
                                                               true>;

    using GemmPipelineProblem =
        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;

    using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;

    // Round K up to a whole number of K tiles per split-K batch, then derive
    // the loop count and tail classification for this problem size.
    const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
    const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
    const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
    float ave_time{0};

    // Invoked by TailHandler with compile-time hot-loop/tail constants;
    // instantiates the concrete pipeline + epilogue + kernel and launches it.
    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
        constexpr bool has_hot_loop_v = has_hot_loop_.value;
        constexpr auto tail_number_v = tail_number_.value;
        constexpr auto scheduler = FlatmmConfig::Scheduler;

        using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
                                                                      BDataType,
                                                                      AccDataType,
                                                                      CodegenFlatmmShape,
                                                                      CodegenGemmTraits,
                                                                      scheduler,
                                                                      has_hot_loop_v,
                                                                      tail_number_v>;

        using CodegenFlatmmPipeline =
            ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem<ADataType,
                                             BDataType,
                                             DsDatatype,
                                             AccDataType,
                                             CDataType,
                                             DsLayout,
                                             ELayout,
                                             CDEElementWise,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             FlatmmConfig::M_Warp,
                                             FlatmmConfig::N_Warp,
                                             FlatmmConfig::M_Warp_Tile,
                                             FlatmmConfig::N_Warp_Tile,
                                             FlatmmConfig::K_Warp_Tile,
                                             CodegenPipelineProblem::TransposeC,
                                             FlatmmConfig::NumWaveGroups,
                                             false,
                                             1,
                                             FlatmmConfig::TiledMMAPermuteN>>;

        // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
        // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
        using Kernel = ck_tile::FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;

        auto kargs = Kernel::MakeKernelArgs(args);

        const dim3 grids = Kernel::GridSize(kargs);
        constexpr dim3 blocks = Kernel::BlockSize();

        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
                      << "Shape: " << CodegenFlatmmShape::GetName() << "\n"
                      << "problem: " << CodegenPipelineProblem::GetName() << "\n"
                      << "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        if(s.flush_cache_)
        {
            // Benchmark with cold caches: rotate through copies of A/B and
            // flush the instruction cache before every timed launch.
            std::cout << "Flushing cache..." << std::endl;
            static constexpr ck_tile::index_t APackedSize =
                std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
            static constexpr ck_tile::index_t BPackedSize =
                std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;

            ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
                args.M, args.K, args.stride_A, is_row_major(ALayout{})));
            ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
                args.K, args.N, args.stride_B, is_row_major(BLayout{})));

            auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
            auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;

            ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
                kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
            rotating_mem.Print();

            auto run_flush_cache = [&]() {
                // flush icache
                ck_tile::flush_icache();
                // rotating mem
                rotating_mem.Next();
                // clear c mem
                if(args.k_batch > 1)
                    hipGetErrorString(hipMemsetAsync(
                        args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
            };
            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                run_flush_cache,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        else
        {
            ave_time = ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        return ave_time;
    };

    BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
    return ave_time;
}
|
||||
|
||||
// Convenience wrapper: packs device buffers and problem sizes into
// ScaleFlatmmHostArgs, runs flatmm_calc, then prints TFLOPS / bandwidth.
// Returns the averaged kernel time in milliseconds.
// NOTE(review): b_shuffle_dev_buf must already hold the pre-shuffled B
// (see shuffle_b_v0/v1) — this function does not shuffle.
template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ScaleM,
          typename ScaleN,
          bool UsePersistentKernel = false,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
                    ck_tile::DeviceMem& b_shuffle_dev_buf,
                    ck_tile::DeviceMem& c_dev_buf,
                    ck_tile::index_t M,
                    ck_tile::index_t N,
                    ck_tile::index_t K,
                    ck_tile::index_t stride_A,
                    ck_tile::index_t stride_B,
                    ck_tile::index_t stride_C,
                    ck_tile::index_t kbatch,
                    ScaleM scale_m,
                    ScaleN scale_n,
                    int n_warmup,
                    int n_repeat)
{
    // Aggregate initialization: field order must match ScaleFlatmmHostArgs.
    // The two {} entries are the (unused here) Ds pointers/strides.
    ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
                                                         b_shuffle_dev_buf.GetDeviceBuffer(),
                                                         {},
                                                         c_dev_buf.GetDeviceBuffer(),
                                                         kbatch,
                                                         M,
                                                         N,
                                                         K,
                                                         stride_A,
                                                         stride_B,
                                                         {},
                                                         stride_C,
                                                         scale_m,
                                                         scale_n};

    float ave_time = flatmm_calc<FlatmmConfig,
                                 ADataType,
                                 BDataType,
                                 DsDatatype,
                                 AccDataType,
                                 CDataType,
                                 ALayout,
                                 BLayout,
                                 DsLayout,
                                 CLayout,
                                 ScaleM,
                                 ScaleN,
                                 UsePersistentKernel,
                                 CDEElementWise>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});

    // Report throughput: 2*M*N*K FLOPs and total bytes moved for A, B and C.
    std::size_t flop = std::size_t(2) * M * N * K;
    std::size_t num_byte =
        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Run Flatmm kernel with DataType = " << DataTypeToString<ADataType>()
              << " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
              << " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
              << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;

    return ave_time;
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "256", "m dimension")
|
||||
.insert("n", "256", "n dimension")
|
||||
.insert("k", "128", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
#include "run_flatmm_example.inc"
|
||||
|
||||
// Dispatches the flatmm example over (data type, scale, persistent) runtime
// options for the R x C -> R layout combination. FlatmmConfig is a config
// template (e.g. FlatmmConfig16) instantiated with the chosen element type.
//
// The template arguments <..., 1, 1> / <..., -1, -1> passed to
// run_flatmm_example_with_layouts select the scale variant, and the trailing
// `true` selects the persistent kernel (defined in run_flatmm_example.inc).
//
// NOTE(review): this function returns -1 on every path — including success —
// and main() negates the result, so the process exit code is 0 even when
// argument parsing fails here. Confirm whether a success/failure return was
// intended.
template <template <typename PreType> typename FlatmmConfig>
int run_flatmm_example(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    std::string data_type = arg_parser.get_str("prec");
    std::string a_layout = arg_parser.get_str("a_layout");
    std::string b_layout = arg_parser.get_str("b_layout");
    int scale_opt = arg_parser.get_int("scale");
    int persistent_opt = arg_parser.get_int("persistent");
    // Only row-major A with column-major B is supported by this example.
    if(a_layout == "R" && b_layout == "C")
    {
        if(data_type == "fp16")
        {
            run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
                argc, argv, Row{}, Col{}, Row{});
        }
        else if(data_type == "bf16")
        {
            run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
                argc, argv, Row{}, Col{}, Row{});
        }
        else if(data_type == "fp8")
        {
            // fp8 additionally honors the scale and persistent options.
            if(scale_opt == 0)
            {
                if(persistent_opt == 0)
                {
                    run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
                        argc, argv, Row{}, Col{}, Row{});
                }
                else
                {
                    run_flatmm_example_with_layouts<ck_tile::fp8_t,
                                                    FlatmmConfig<ck_tile::fp8_t>,
                                                    -1,
                                                    -1,
                                                    true>(argc, argv, Row{}, Col{}, Row{});
                }
            }
            else
            {
                if(persistent_opt == 0)
                {
                    run_flatmm_example_with_layouts<ck_tile::fp8_t,
                                                    FlatmmConfig<ck_tile::fp8_t>,
                                                    1,
                                                    1>(argc, argv, Row{}, Col{}, Row{});
                }
                else
                {
                    run_flatmm_example_with_layouts<ck_tile::fp8_t,
                                                    FlatmmConfig<ck_tile::fp8_t>,
                                                    1,
                                                    1,
                                                    true>(argc, argv, Row{}, Col{}, Row{});
                }
            }
        }
        else if(data_type == "bf8")
        {
            // bf8 honors the scale option (no persistent variant wired here).
            if(scale_opt == 0)
            {
                run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
                    argc, argv, Row{}, Col{}, Row{});
            }
            else
            {
                run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>, 1, 1>(
                    argc, argv, Row{}, Col{}, Row{});
            }
        }
        else
        {
            throw std::runtime_error("Unsupported data_type!");
        }
    }
    else
    {
        throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
    }
    return -1;
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 2)
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
}
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
160
example/ck_tile/18_flatmm/flatmm_basic.hpp
Normal file
160
example/ck_tile/18_flatmm/flatmm_basic.hpp
Normal file
@@ -0,0 +1,160 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
// GEMM config with 32x32 warp tile (fixed typo: was "32x132").
// Compile-time tuning parameters consumed by flatmm_calc; sizes that depend
// on element width are expressed via sizeof(DataType).
template <typename DataType>
struct FlatmmConfig32
{
    // Per-workgroup block tile (K scales inversely with element size).
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);

    // Warp layout inside the block: 1 x 4 x 1.
    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    // 32x32 MFMA warp tile; K depth 16 for 2-byte types, else 32.
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 32;

    // No padding: problem sizes must divide the tile sizes.
    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    static constexpr bool TransposeC = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr int kBlockPerCu = 1;
    static constexpr int TileParitionerGroupNum = 8;
    static constexpr int TileParitionerM01 = 4;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool DoubleSmemBuffer = false;
    static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
};
|
||||
|
||||
// gfx950 variant of the 32x32 config: only deepens the warp-tile K depth for
// 1-byte types (64 vs 32). The member shadows the base-class value.
template <typename DataType>
struct FlatmmConfig32_950 : public FlatmmConfig32<DataType>
{
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64;
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
// Same block/warp layout as FlatmmConfig32 but with 16x16 MFMA warp tiles and
// deeper K per warp tile.
template <typename DataType>
struct FlatmmConfig16
{
    // Per-workgroup block tile (K scales inversely with element size).
    static constexpr ck_tile::index_t M_Tile = 128;
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);

    // Warp layout inside the block: 1 x 4 x 1.
    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    // 16x16 MFMA warp tile; K depth 32 for 2-byte types, else 64.
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 64;

    // No padding: problem sizes must divide the tile sizes.
    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    static constexpr bool TransposeC = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr int kBlockPerCu = 1;
    static constexpr int TileParitionerGroupNum = 8;
    static constexpr int TileParitionerM01 = 4;
    static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool DoubleSmemBuffer = false;

    // PermuteN only when the per-warp N repeat count is a multiple of 4
    // (here: 128 / 16 / 4 == 2, so PermuteN is disabled).
    static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
    static constexpr bool TiledMMAPermuteN = N_Repeat % 4 == 0;
};
|
||||
|
||||
// gfx950 variant of the 16x16 config: wider N tile, deeper K.
// NOTE: these members shadow (not override) the base-class values; anything
// looking up names through FlatmmConfig16_950 sees the new values, while the
// base class itself is unchanged.
template <typename DataType>
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
{
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
    static constexpr int kBlockPerCu = 1;

    // Recomputed with the widened N_Tile (256 / 16 / 4 == 4), so PermuteN is
    // enabled here, unlike the base config.
    static constexpr int N_Repeat =
        N_Tile / FlatmmConfig16<DataType>::N_Warp_Tile / FlatmmConfig16<DataType>::N_Warp;
    static constexpr bool TiledMMAPermuteN = N_Repeat % 4 == 0;
};
|
||||
|
||||
// Maps an input element type to the full (A, B, Acc, C) type set used by the
// example. Only the specializations below are defined; other types fail to
// compile. Note the 8-bit float inputs accumulate in float and emit fp16 C.
template <typename ADataType>
struct GemmBasicTypeConfig;

template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
    using ADataType = ck_tile::half_t;
    using BDataType = ck_tile::half_t;
    using AccDataType = float;
    using CDataType = ck_tile::half_t;
    // ToDo: Add more bias config to support different categories of GEMM.
};

template <>
struct GemmBasicTypeConfig<ck_tile::bf16_t>
{
    using ADataType = ck_tile::bf16_t;
    using BDataType = ck_tile::bf16_t;
    using AccDataType = float;
    using CDataType = ck_tile::bf16_t;
};

template <>
struct GemmBasicTypeConfig<ck_tile::fp8_t>
{
    using ADataType = ck_tile::fp8_t;
    using BDataType = ck_tile::fp8_t;
    using AccDataType = float;
    // fp8 inputs produce half-precision output.
    using CDataType = ck_tile::half_t;
    // ToDo: Add more bias config to support different categories of GEMM.
};

template <>
struct GemmBasicTypeConfig<ck_tile::bf8_t>
{
    using ADataType = ck_tile::bf8_t;
    using BDataType = ck_tile::bf8_t;
    using AccDataType = float;
    // bf8 inputs produce half-precision output.
    using CDataType = ck_tile::half_t;
};
|
||||
|
||||
template <typename T>
|
||||
struct is_8bit_type
|
||||
: std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>
|
||||
{
|
||||
};
|
||||
|
||||
// Forward declaration of the flatmm launcher; the definition lives in the
// example translation unit (flatmm_basic.cpp). Returns the averaged kernel
// time in milliseconds.
template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ScaleM,
          typename ScaleN,
          bool persistent,
          typename CDEElementWise>
float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
                  const ck_tile::stream_config& s);
|
||||
345
example/ck_tile/18_flatmm/grouped_flatmm.cpp
Normal file
345
example/ck_tile/18_flatmm/grouped_flatmm.cpp
Normal file
@@ -0,0 +1,345 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <numeric>
|
||||
|
||||
#include "flatmm_basic.hpp"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
// Returns a compile-time bool_constant that is true iff the given layout tag
// is RowMajor (duplicate of the helper in flatmm_basic.cpp; used for host
// tensor-descriptor construction).
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
    return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
                                                 ck_tile::tensor_layout::gemm::RowMajor>>{};
}
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("Ms", "1,1,1", "m dimension")
|
||||
.insert("Ns", "5120,5120,5120", "n dimension")
|
||||
.insert("Ks", "6144,6144,6144", "k dimension")
|
||||
.insert("group_count", "3", "group count")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("mode",
|
||||
"masked",
|
||||
"grouped gemm mode: [general | contiguous | masked], general by default")
|
||||
.insert("wave_tile", "16", "only support 16(16x16) or 32(32x32)")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("scale", "0", "0:without scale, 1:per-token/channel scale, only for fp8/bf8")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool persistent,
|
||||
typename CDEElementWise,
|
||||
typename KernelArguments>
|
||||
float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
persistent,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel =
|
||||
ck_tile::GroupedFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.group_count * args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.group_count * args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(args.e_ptr,
|
||||
0,
|
||||
args.group_count * args.M * args.N * sizeof(CDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_flatmm_example.inc"
|
||||
|
||||
template <template <typename PreType> typename FlatmmConfig>
|
||||
int run_grouped_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string mode = arg_parser.get_str("mode");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(mode == "contiguous")
|
||||
{
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf16_t,
|
||||
FlatmmConfig<ck_tile::bf16_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
run_contiguous_grouped_flatmm_example_with_layouts<ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else if(mode == "masked")
|
||||
{
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
run_masked_grouped_flatmm_example_with_layouts<ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
run_masked_grouped_flatmm_example_with_layouts<ck_tile::bf16_t,
|
||||
FlatmmConfig<ck_tile::bf16_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
run_masked_grouped_flatmm_example_with_layouts<ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
run_masked_grouped_flatmm_example_with_layouts<ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported mode!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_grouped_flatmm_example<FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
// else if(warp_tile == 1)
|
||||
// {
|
||||
// return !run_grouped_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
// }
|
||||
// else if(warp_tile == 2)
|
||||
// {
|
||||
// return !run_grouped_flatmm_example<FlatmmConfig16_950>(argc, argv);
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// return !run_grouped_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
// }
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
49
example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp
Normal file
49
example/ck_tile/18_flatmm/mixed_prec/a16w4_flatmm.hpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr int N_Repeat =
|
||||
N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
515
example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp
Normal file
515
example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp
Normal file
@@ -0,0 +1,515 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "a16w4_moe_flatmm.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/moe_flatmm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/reference/reference_moe_gemm.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// gemm1
|
||||
// operand-A = [num_token, d_model]
|
||||
// operand-B = [num_expert, hidden, d_model]
|
||||
// operand-C = [num_token, topk, hidden]
|
||||
|
||||
// gemm2
|
||||
// operand-A = [num_token, topk, hidden]
|
||||
// operand-B = [num_expert, d_model, hidden]
|
||||
// operand-C = [num_token, d_model]
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename MoeFlatmmHostArgs>
|
||||
float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
false, // UsePersistentKernel_
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>; // Preshuffle_
|
||||
|
||||
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
|
||||
|
||||
if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
static_assert(
|
||||
FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
|
||||
"requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
|
||||
}
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
|
||||
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
using CodegenPipelineProblem =
|
||||
std::conditional_t<MXFP4_Pipeline,
|
||||
ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using CodegenFlatmmPipeline = std::conditional_t<
|
||||
MXFP4_Pipeline,
|
||||
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>,
|
||||
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>>;
|
||||
using FusedAct =
|
||||
std::conditional_t<MXFP4_Pipeline, ck_tile::moe::Swiglu, ck_tile::moe::MoeSilu>;
|
||||
|
||||
using Kernel = ck_tile::MoeFlatmmKernel<TilePartitioner,
|
||||
CodegenFlatmmPipeline,
|
||||
GemmEpilogue,
|
||||
moe_kind,
|
||||
FusedAct>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel " << Kernel::GetName() << "\n"
|
||||
<< "with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< "\n"
|
||||
<< "k_batch: " << kargs.k_batch << std::endl;
|
||||
}
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
|
||||
? 2
|
||||
: 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
|
||||
? 2
|
||||
: 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK
|
||||
: args.NumTokens,
|
||||
args.K,
|
||||
args.stride_A,
|
||||
is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
const int outputN =
|
||||
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N;
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), s.stream_id_));
|
||||
else if(args.k_batch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(args.e_ptr,
|
||||
0,
|
||||
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, class IterSrc, class IterDst>
|
||||
void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N, int K)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = FlatmmConfig::N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
|
||||
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
int up_stride = N / 2 / NLane;
|
||||
|
||||
for(long eid = 0; eid < experts_cnt; ++eid)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
// interleave gate and up part with granularity is 16.
|
||||
int n0_interleave = n >= N / 2 ? (n0 - up_stride) * 2 + 1 : // up part
|
||||
n0 * 2; // gate part
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
long outputIndex = eid * N * K_pk + n0_interleave * KPack * NLane * KLane * K0 +
|
||||
k0 * KPack * NLane * KLane + k1 * KPack * NLane +
|
||||
n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(long eid = 0; eid < experts_cnt; ++eid)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
long outputIndex = eid * N * K_pk + n0 * KPack * NLane * KLane * K0 +
|
||||
k0 * KPack * NLane * KLane + k1 * KPack * NLane +
|
||||
n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, typename T>
|
||||
auto shuffle_mxfp4_scale(const ck_tile::HostTensor<T>& scale, int experts_cnt)
|
||||
{
|
||||
assert(scale.get_lengths().size() == 2);
|
||||
int n_ = scale.get_lengths()[1];
|
||||
int k_ = scale.get_lengths()[0];
|
||||
|
||||
int k_per_expert = k_ / experts_cnt;
|
||||
|
||||
constexpr int K_Pack = 2; // fixed for mxfp4
|
||||
constexpr int N_Pack = 2; // fixed for mxfp4
|
||||
constexpr int GranularityK = 32; // fixed for mxfp4
|
||||
|
||||
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
|
||||
|
||||
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
|
||||
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
|
||||
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
|
||||
|
||||
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
ck_tile::HostTensor<T> shfl_scale({
|
||||
experts_cnt,
|
||||
k_per_expert / K_Pack / K_Lane,
|
||||
K_Pack,
|
||||
K_Lane,
|
||||
N_Pack, // N_Pack = 2 is composed of Gate + Up.
|
||||
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
});
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
return ck_tile::reference_permute(shfl_scale, {0, 5, 1, 3, 6, 2, 4});
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::HostTensor<T> shfl_scale({
|
||||
experts_cnt,
|
||||
k_per_expert / K_Pack / K_Lane,
|
||||
K_Pack,
|
||||
K_Lane,
|
||||
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
|
||||
N_Pack,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
});
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
return ck_tile::reference_permute(shfl_scale, {0, 4, 1, 3, 6, 2, 5});
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_a16w4_moe_flatmm_example.inc"
|
||||
|
||||
template <typename FlatmmConfig>
|
||||
int run_a16w4_moe_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
const std::string mixed_prec = arg_parser.get_str("mixed_prec");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
const std::string gemm_kind = arg_parser.get_str("gemm_kind");
|
||||
if(gemm_kind == "gemm1_gate_up")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm2")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm2!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm1_split_k")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_split_k!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
|
||||
"[gemm1_gate_up | gemm1_split_k | gemm2]");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
// else if(warp_tile == 1)
|
||||
// {
|
||||
// return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
|
||||
// }
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
88
example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp
Normal file
88
example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.hpp
Normal file
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/moe_flatmm.hpp"
|
||||
|
||||
// GEMM config with 16x16 warp tile
// Compile-time tuning knobs consumed as non-type members by the flatmm
// pipeline/kernel templates. Every member name is part of the interface
// expected by FlatmmConfig-templated code; do not rename.
struct A16W4_FlatmmConfig16
{
    // Per-workgroup tile extents (elements).
    static constexpr ck_tile::index_t M_Tile = 32;
    static constexpr ck_tile::index_t N_Tile = 256;
    static constexpr ck_tile::index_t K_Tile = 256;

    // Wave grid inside a workgroup: 1 x 4 x 1.
    static constexpr ck_tile::index_t M_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 4;
    static constexpr ck_tile::index_t K_Warp = 1;

    // Per-wave MFMA tile: 16x16 with K depth 32.
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;
    static constexpr ck_tile::index_t K_Warp_Tile = 32;

    // No padding: problem sizes are assumed tile-aligned.
    static constexpr bool kPadM = false;
    static constexpr bool kPadN = false;
    static constexpr bool kPadK = false;

    static constexpr bool TransposeC            = false;
    static constexpr bool UseStructuredSparsity = false;

    static constexpr int kBlockPerCu = 1;
    // NOTE(review): "Paritioner" spelling matches the library-side member
    // lookups; keep as-is.
    static constexpr int TileParitionerGroupNum = 8;
    static constexpr int TileParitionerM01      = 4;
    static constexpr auto Scheduler             = ck_tile::GemmPipelineScheduler::Default;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr bool DoubleSmemBuffer          = false;

    // N_Tile / N_Warp_Tile / N_Warp = 256/16/4 = 4 repeats along N per wave.
    static constexpr int N_Repeat          = N_Tile / N_Warp_Tile / N_Warp;
    static constexpr bool TiledMMAPermuteN = false;
};
|
||||
|
||||
// gfx950 variant: same wave shape as A16W4_FlatmmConfig16 but a narrower
// N tile. The redefined members *shadow* (not override) the base values, so
// they only take effect when the pipeline is instantiated with this derived
// type directly.
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
{
    static constexpr ck_tile::index_t N_Tile = 128;
    static constexpr int kBlockPerCu         = 1;

    // Recomputed from the shadowed N_Tile; wave shape is inherited, so
    // 128/16/4 = 2 repeats along N.
    static constexpr int N_Repeat =
        N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
    static constexpr bool TiledMMAPermuteN = false;
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("experts", "8", "Num of experts - 8 by default")
|
||||
.insert("NumTokens", "128", "M dimensions - 128 by default.")
|
||||
.insert("TopK", "3", "Top K - 3 by default.")
|
||||
.insert("N", "4096", "N dimensions - 4096 by default.")
|
||||
.insert("K", "4096", "K dimensions - 4096 by default.")
|
||||
.insert("stride_A", "", "Tensor A strides - it is empty by default.")
|
||||
.insert("stride_B", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_C", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("gemm_kind",
|
||||
"gemm1_gate_up",
|
||||
"Gemm kind in FFN network [gemm1_gate_up | gemm2 | gemm1_split_k] - "
|
||||
"gemm1_gate_up by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("mixed_prec",
|
||||
"bf16xfp4",
|
||||
"data type for activation and weight, support: bf16xfp4, fp16xfp4")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.")
|
||||
.insert("k_batch", "1", "parallism to control splik-k.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
462
example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp
Normal file
462
example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp
Normal file
@@ -0,0 +1,462 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mixed_prec_flatmm.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Instantiates and launches the mixed-precision (wide-A x fp4-B) flatmm
// kernel described by FlatmmConfig, then returns the averaged kernel time in
// milliseconds as measured by the ck_tile launch helpers.
//
// args: host-side problem description (pointers, sizes, strides, scales).
// s:    stream/benchmark configuration (warmup/repeat, cache flushing, log).
//
// Throws std::runtime_error if the kernel rejects the arguments.
template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ScaleM,
          typename ScaleN,
          bool persistent,
          typename CDEElementWise>
float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
                             const ck_tile::stream_config& s)
{
    // Tile shape: block tiles, wave grid, and per-wave MFMA tile, all taken
    // from the compile-time config.
    using CodegenFlatmmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
        ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
        ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
                          FlatmmConfig::N_Warp_Tile,
                          FlatmmConfig::K_Warp_Tile>>;

    using TilePartitioner =
        ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
                                                   FlatmmConfig::TileParitionerGroupNum,
                                                   FlatmmConfig::TileParitionerM01>;

    using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
                                           FlatmmConfig::kPadN,
                                           FlatmmConfig::kPadK,
                                           ALayout,
                                           BLayout,
                                           ELayout,
                                           FlatmmConfig::NumWaveGroups>;

    using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
                                                               FlatmmConfig::kPadN,
                                                               FlatmmConfig::kPadK,
                                                               FlatmmConfig::DoubleSmemBuffer,
                                                               ALayout,
                                                               BLayout,
                                                               ELayout,
                                                               FlatmmConfig::TransposeC,
                                                               FlatmmConfig::UseStructuredSparsity,
                                                               persistent,
                                                               FlatmmConfig::NumWaveGroups,
                                                               true>;

    // The wider activation type is also used as the compute type; the
    // narrow weight type (fp4) is up-converted inside the pipeline.
    using ComputeDataType = ADataType;
    static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
                  "mixed_prec_flatmm requires ADataType is a wider type than BDataType");

    using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
                                                             ComputeDataType,
                                                             AccDataType,
                                                             CodegenFlatmmShape,
                                                             Traits>;

    using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;

    // Derive loop count for the split-K-adjusted K extent, then hot-loop /
    // tail classification, which selects the pipeline specialization below.
    const ck_tile::index_t k_grain     = args.k_batch * FlatmmConfig::K_Tile;
    const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
    const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
    float ave_time{0};

    // Invoked by TailHandler with compile-time (hot-loop, tail-number) tags;
    // builds the concrete kernel type and launches it.
    const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
        constexpr bool has_hot_loop_v = has_hot_loop_.value;
        constexpr auto tail_number_v  = tail_number_.value;
        constexpr auto scheduler      = FlatmmConfig::Scheduler;

        constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern

        using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
                                                                              BDataType,
                                                                              AccDataType,
                                                                              CodegenFlatmmShape,
                                                                              CodegenGemmTraits,
                                                                              scheduler,
                                                                              has_hot_loop_v,
                                                                              tail_number_v>;

        using CodegenFlatmmPipeline =
            ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

        // CShuffle epilogue: shuffles accumulators through LDS and applies the
        // CDE element-wise op before the final store.
        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem<ComputeDataType,
                                             ComputeDataType,
                                             DsDatatype,
                                             AccDataType,
                                             CDataType,
                                             DsLayout,
                                             ELayout,
                                             CDEElementWise,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             FlatmmConfig::M_Warp,
                                             FlatmmConfig::N_Warp,
                                             FlatmmConfig::M_Warp_Tile,
                                             FlatmmConfig::N_Warp_Tile,
                                             FlatmmConfig::K_Warp_Tile,
                                             CodegenPipelineProblem::TransposeC,
                                             FlatmmConfig::NumWaveGroups,
                                             false, // FixedVectorSize
                                             1,     // VectorSizeC
                                             FlatmmConfig::TiledMMAPermuteN,
                                             BlockedXDLN_PerWarp>>;

        using Kernel =
            ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;

        auto kargs = Kernel::MakeKernelArgs(args);

        const dim3 grids      = Kernel::GridSize(kargs);
        constexpr dim3 blocks = Kernel::BlockSize();

        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
                      << "Shape: " << CodegenFlatmmShape::GetName() << "\n"
                      << "problem: " << CodegenPipelineProblem::GetName() << "\n"
                      << "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        if(s.flush_cache_)
        {
            // Benchmark mode with cache flushing: rotate A/B buffers and
            // flush the instruction cache between timed launches so each
            // iteration sees cold caches.
            std::cout << "Flushing cache..." << std::endl;
            constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
            constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;

            // Host tensors are created only to compute the byte sizes of the
            // rotating copies (divided by PackedSize for sub-byte types).
            ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
                args.M, args.K, args.stride_A, is_row_major(ALayout{})));
            ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
                args.K, args.N, args.stride_B, is_row_major(BLayout{})));

            auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
            auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;

            ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
                kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
            rotating_mem.Print();

            auto run_flush_cache = [&]() {
                // flush icache
                ck_tile::flush_icache();
                // rotating mem
                rotating_mem.Next();
                // clear c mem
                // NOTE(review): hipMemsetAsync's error code is swallowed via
                // hipGetErrorString; failures here go unnoticed.
                if(args.k_batch > 1)
                    hipGetErrorString(hipMemsetAsync(
                        args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
            };
            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                run_flush_cache,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        else
        {
            ave_time = ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        return ave_time;
    };

    // Dispatches Run with the compile-time tags matching (has_hot_loop,
    // tail_num); ave_time is written inside the lambda.
    BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
    return ave_time;
}
|
||||
|
||||
// Convenience wrapper: packs device buffers and problem sizes into
// ScaleFlatmmHostArgs, runs mixed_prec_flatmm_calc with a benchmark stream
// config (n_warmup warmups, n_repeat timed repeats), prints a perf summary,
// and returns the averaged kernel time in ms.
template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename ScaleN,
          bool UsePersistentKernel = false,
          typename CDEElementWise  = ck_tile::element_wise::PassThrough>
float invoke_mixed_prec_flatmm(ck_tile::DeviceMem& a_dev_buf,
                               ck_tile::DeviceMem& b_shuffle_dev_buf, // pre-shuffled weights
                               ck_tile::DeviceMem& c_dev_buf,
                               ck_tile::index_t M,
                               ck_tile::index_t N,
                               ck_tile::index_t K,
                               ck_tile::index_t stride_A,
                               ck_tile::index_t stride_B,
                               ck_tile::index_t stride_C,
                               ck_tile::index_t kbatch,
                               ScaleN dequant_scale_n, // per-N weight dequant scales
                               int n_warmup,
                               int n_repeat)
{
    // Activation has no scale
    using ActScaleType = ck_tile::FlatmmScalePointer<-1>;

    // CAUTION: brace-initialization is positional; the {} entries are the
    // Ds pointers / Ds strides / activation-scale slots left defaulted.
    ck_tile::ScaleFlatmmHostArgs<ActScaleType, ScaleN> args = {a_dev_buf.GetDeviceBuffer(),
                                                               b_shuffle_dev_buf.GetDeviceBuffer(),
                                                               {},
                                                               c_dev_buf.GetDeviceBuffer(),
                                                               kbatch,
                                                               M,
                                                               N,
                                                               K,
                                                               stride_A,
                                                               stride_B,
                                                               {},
                                                               stride_C,
                                                               {},
                                                               dequant_scale_n};

    float ave_time = mixed_prec_flatmm_calc<FlatmmConfig,
                                            ADataType,
                                            BDataType,
                                            DsDatatype,
                                            AccDataType,
                                            CDataType,
                                            ALayout,
                                            BLayout,
                                            DsLayout,
                                            CLayout,
                                            ActScaleType,
                                            ScaleN,
                                            UsePersistentKernel,
                                            CDEElementWise>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});

    // PackedSize accounts for sub-byte weights (2 fp4 values per byte).
    constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;

    std::size_t flop     = std::size_t(2) * M * N * K;
    std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K / PackedSize +
                           sizeof(CDataType) * M * N;
    // ave_time is in ms: flop/1e9/ms -> TFLOP/s, bytes/1e6/ms -> GB/s.
    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Run A16W4_Flatmm kernel " << " M =" << M << " N =" << N << " K =" << K
              << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
              << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << std::endl;

    return ave_time;
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "256", "m dimension")
|
||||
.insert("n", "256", "n dimension")
|
||||
.insert("k", "512", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on GPU")
|
||||
.insert("mixed_prec",
|
||||
"bf16xfp4",
|
||||
"data type for activation and weight, support: bf16xfp4, fp16xfp4")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile",
|
||||
"0",
|
||||
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, class IterSrc, class IterDst>
|
||||
void preShuffleWeight(const IterSrc src, IterDst dst, int N, int K)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = FlatmmConfig::N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pre-shuffles a 2-D dequant-scale tensor [K/granularity, N] into the lane
// order the kernel's scale loads expect, by viewing it as a 6-D tensor and
// permuting axes. Returns the permuted host tensor.
template <class FlatmmConfig, class T>
auto preShuffleScale(const ck_tile::HostTensor<T>& scale)
{
    assert(scale.get_lengths().size() == 2);
    // Layout on entry: dim 0 = K-groups, dim 1 = N.
    int n_ = scale.get_lengths()[1];
    int k_ = scale.get_lengths()[0];

    constexpr int K_Pack       = 2;  // fixed for mxfp4
    constexpr int N_Pack       = 2;  // fixed for mxfp4
    constexpr int GranularityK = 32; // fixed for mxfp4

    constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4

    // The shuffle pattern below is only valid for 16-wide XDL N tiles and
    // tile sizes divisible by the pack/lane granularity.
    static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
    static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
    static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);

    // 6-D view: [K0, K_Pack, K_Lane, N0, N_Pack, NLane] over the same data.
    ck_tile::HostTensor<T> shfl_scale({
        k_ / K_Pack / K_Lane,
        K_Pack,
        K_Lane,
        n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
        N_Pack,
        FlatmmConfig::N_Warp_Tile,
    });
    std::copy(scale.begin(), scale.end(), shfl_scale.begin());
    // Reorder to [N0, K0, K_Lane, NLane, K_Pack, N_Pack].
    return ck_tile::reference_permute(shfl_scale, {3, 0, 2, 5, 1, 4});
}
|
||||
|
||||
#include "run_mixed_prec_flatmm.inc"
|
||||
|
||||
template <typename FlatmmConfig>
|
||||
int run_mixed_prec_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string mixed_prec = arg_parser.get_str("mixed_prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
int persistent_opt = arg_parser.get_int("persistent");
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::bf16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
{
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
false>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
run_mixed_prec_flatmm_with_layouts<ck_tile::fp16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
true>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
return !run_mixed_prec_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported warp_tile!");
|
||||
}
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
14
example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.hpp
Normal file
14
example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.hpp
Normal file
@@ -0,0 +1,14 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include "a16w4_flatmm.hpp"
|
||||
@@ -0,0 +1,358 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
// Runs the a16w4 MoE GEMM with a benchmark stream config (n_warmup warmups,
// n_repeat timed repeats), prints a perf summary line, and returns the
// averaged kernel time in ms as reported by a16w4_moe_gemm.
template <typename FlatmmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDatatype,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ck_tile::MoeFlatmmKind kind,
          typename CDEElementWise = ck_tile::element_wise::PassThrough,
          typename MoeHostArgs>
float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
{
    float ave_time = a16w4_moe_gemm<FlatmmConfig,
                                    ADataType,
                                    BDataType,
                                    DsDatatype,
                                    AccDataType,
                                    CDataType,
                                    ALayout,
                                    BLayout,
                                    DsLayout,
                                    ELayout,
                                    kind,
                                    CDEElementWise>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});

    std::string op_name{"Moe Gemm"};

    // PackedSize accounts for sub-byte weights (2 fp4 values per byte).
    constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;

    // NOTE(review): flop/byte counts use the dense M*N*K problem; top-k
    // routing effects are not reflected here.
    std::size_t flop     = std::size_t(2) * args.M * args.N * args.K;
    std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
                           sizeof(BDataType) * args.N * args.K / PackedSize +
                           sizeof(CDataType) * args.M * args.N;
    // ave_time is in ms: flop/1e9/ms -> TFLOP/s, bytes/1e6/ms -> GB/s.
    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
    float gb_per_sec = num_byte / 1.E6 / ave_time;

    std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
              << gb_per_sec << " GB/s, " << op_name << std::endl;

    return ave_time;
}
|
||||
|
||||
template <typename PrecActType,
|
||||
typename PrecWeightType,
|
||||
typename FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind kind,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using CDataType =
|
||||
std::conditional_t<kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k, float, PrecActType>;
|
||||
using AccDataType = float;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
constexpr int ScaleGranularityN = 1;
|
||||
constexpr int ScaleGranularityK = 32;
|
||||
|
||||
const ck_tile::index_t N = arg_parser.get_int("N");
|
||||
const ck_tile::index_t K = arg_parser.get_int("K");
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_A");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_B");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_C");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
|
||||
const ck_tile::index_t topk = arg_parser.get_int("TopK");
|
||||
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
|
||||
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
|
||||
const ck_tile::index_t experts = arg_parser.get_int("experts");
|
||||
const ck_tile::index_t k_batch = arg_parser.get_int("k_batch");
|
||||
|
||||
// TODO: replace the magic declaration
|
||||
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
|
||||
|
||||
ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
|
||||
ck_tile::index_t valid_tile_num = sorted_tile_num;
|
||||
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
|
||||
const ck_tile::index_t M = sorted_tile_num * MPerBlock;
|
||||
const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;
|
||||
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(
|
||||
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(
|
||||
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
|
||||
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
|
||||
auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
|
||||
is_row_major(b_layout)
|
||||
? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
|
||||
: ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
|
||||
auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
|
||||
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
|
||||
{K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.f, 1.f}(scale_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f}(b_k_n_tensor);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.0f, 1.0f}(scale_b);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host(
|
||||
ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
|
||||
shuffle_mxfp4_weight<FlatmmConfig, kind>(
|
||||
b_k_n_tensor.begin(), b_shuffle_host.begin(), experts, N, K);
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b_shuffle =
|
||||
shuffle_mxfp4_scale<FlatmmConfig, kind>(scale_b, experts);
|
||||
ck_tile::DeviceMem scale_b_shuffle_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());
|
||||
|
||||
std::cout << "moe_flatmm:" << "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
|
||||
<< "\n topk: " << topk << "\n sorted_tile_num: " << sorted_tile_num
|
||||
<< "\n problem_n: " << N << "\n problem_k: " << K
|
||||
<< "\n a_m_k: " << a_m_k_tensor.mDesc << "\n b_k_n: " << b_k_n_tensor.mDesc
|
||||
<< "\n b_shuffle: " << b_shuffle_host.mDesc << "\n c_m_n: " << c_m_n_tensor.mDesc
|
||||
<< std::endl;
|
||||
|
||||
ck_tile::HostTensor<ck_tile::index_t> expert_ids(
|
||||
ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
|
||||
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> expert_weight(
|
||||
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
|
||||
ck_tile::HostTensor<ck_tile::index_t> max_token_id(
|
||||
ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
ck_tile::HostTensor<AccDataType> expert_bias(ck_tile::HostTensorDescriptor({experts * N}, {1}));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
// for verification only, no need to satify weight normalization
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.0f, 1.0f}(expert_bias);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.0f, 1.0f}(expert_weight);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 0.0f}(expert_bias);
|
||||
}
|
||||
|
||||
max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
|
||||
// int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
|
||||
}
|
||||
|
||||
int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
// int token_per_tile = num_tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
// sorted_token_ids.mData[0] = 0;
|
||||
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile && tokenid < num_tokens * topk)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = num_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
|
||||
b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_tensor.SetZero();
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_dev{sorted_token_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_ids_dev{expert_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem max_token_id_dev{max_token_id.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_weight_dev{expert_weight.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_bias_dev{expert_bias.get_element_space_size_in_bytes()};
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.data());
|
||||
expert_weight_dev.ToDevice(expert_weight.data());
|
||||
expert_bias_dev.ToDevice(expert_bias.data());
|
||||
scale_b_shuffle_dev_buf.ToDevice(scale_b_shuffle.data());
|
||||
|
||||
const ck_tile::index_t* p_sorted_token_ids_dev =
|
||||
static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
|
||||
const ck_tile::index_t* p_expert_ids_dev =
|
||||
static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
|
||||
const ck_tile::index_t* p_max_token_id_dev =
|
||||
static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
|
||||
const AccDataType* p_sorted_expert_weight_dev =
|
||||
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
|
||||
|
||||
auto scale_b_shuffle_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()),
|
||||
N / ScaleGranularityN};
|
||||
auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<1>{
|
||||
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()), experts * N};
|
||||
|
||||
using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
|
||||
ck_tile::FlatmmScalePointer<-1>,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>,
|
||||
ck_tile::FlatmmScalePointer<1>>;
|
||||
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
|
||||
p_sorted_expert_weight_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
num_tokens,
|
||||
experts,
|
||||
topk,
|
||||
k_batch, // k_batch
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
nullptr,
|
||||
scale_b_shuffle_dev_ptr,
|
||||
exp_bias_dev_ptr};
|
||||
|
||||
invoke_a16w4_moe_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
kind>(warmup, repeat, gemm_desc);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("validate"))
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
|
||||
outputN,
|
||||
stride_C,
|
||||
is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::HostTensor<AccDataType> scale_A(
|
||||
ck_tile::HostTensorDescriptor({1, K / ScaleGranularityK}, {1, 1}));
|
||||
|
||||
// scaleA = 1 has no effect on the result
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
|
||||
ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
|
||||
scale_A_dev_buf.ToDevice(scale_A.data());
|
||||
|
||||
// convert scale_b from e8m0 to float
|
||||
ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
|
||||
{K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));
|
||||
std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
|
||||
ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
|
||||
scale_b_float_dev_buf.ToDevice(scale_b_float.data());
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
|
||||
c_m_n_ref_buf->SetZero();
|
||||
|
||||
ck_tile::reference_moe_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
static_cast<int>(kind),
|
||||
ck_tile::moe::Swiglu>(
|
||||
p_sorted_token_ids_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
static_cast<const ADataType*>(a_m_k_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<const BDataType*>(b_origin_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
|
||||
p_sorted_expert_weight_dev,
|
||||
num_tokens,
|
||||
MPerBlock,
|
||||
topk,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
M,
|
||||
1,
|
||||
ScaleGranularityK,
|
||||
static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()));
|
||||
|
||||
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
181
example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc
Normal file
181
example/ck_tile/18_flatmm/mixed_prec/run_mixed_prec_flatmm.inc
Normal file
@@ -0,0 +1,181 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
template <typename PrecActType,
|
||||
typename PrecWeightType,
|
||||
typename FlatmmConfig,
|
||||
bool UsePersistentKernel = false,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_mixed_prec_flatmm_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using CDataType = PrecActType;
|
||||
using AccDataType = float;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
constexpr int DequantGranularityN = 1;
|
||||
constexpr int DequantGranularityK = 32;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
ck_tile::index_t n_warmup = arg_parser.get_int("warmup");
|
||||
ck_tile::index_t n_repeat = arg_parser.get_int("repeat");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_origin_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_rslt_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
|
||||
{K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.f, 1.f}(scale_b);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
preShuffleWeight<FlatmmConfig>(b_origin_host.begin(), b_shuffle_host.begin(), N, K);
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b_shuffle = preShuffleScale<FlatmmConfig>(scale_b);
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
c_rslt_host.SetZero();
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffle.data());
|
||||
|
||||
auto scale_b_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
|
||||
|
||||
invoke_mixed_prec_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
decltype(scale_b_dev_ptr),
|
||||
UsePersistentKernel>(a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
scale_b_dev_ptr,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
|
||||
bool pass = true;
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::DeviceMem b_origin_dev_buf(b_origin_host.get_element_space_size_in_bytes());
|
||||
b_origin_dev_buf.ToDevice(b_origin_host.data());
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::DeviceMem c_gpu_ref_dev_buf(c_gpu_ref_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::HostTensor<AccDataType> scale_A(
|
||||
ck_tile::HostTensorDescriptor({1, K / DequantGranularityK}, {1, 1}));
|
||||
|
||||
// scaleA = 1 has no effect on the result
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
|
||||
ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
|
||||
scale_A_dev_buf.ToDevice(scale_A.data());
|
||||
|
||||
// convert scale_b from e8m0 to float
|
||||
ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
|
||||
{K / DequantGranularityK, N / DequantGranularityN}, {N / DequantGranularityN, 1}));
|
||||
std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
|
||||
ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
|
||||
scale_b_float_dev_buf.ToDevice(scale_b_float.data());
|
||||
|
||||
c_gpu_ref_dev_buf.SetZero();
|
||||
ck_tile::reference_blockwise_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_origin_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_gpu_ref_dev_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
M,
|
||||
DequantGranularityN,
|
||||
DequantGranularityK,
|
||||
static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_rslt_host, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
447
example/ck_tile/18_flatmm/moe_flatmm.cpp
Normal file
447
example/ck_tile/18_flatmm/moe_flatmm.cpp
Normal file
@@ -0,0 +1,447 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "moe_flatmm.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/moe_flatmm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/reference/reference_moe_gemm.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig, typename T>
|
||||
auto flatmm_shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
constexpr int MaxVecSize = 16 / sizeof(T);
|
||||
constexpr int KLane = ck_tile::get_warp_size() / FlatmmConfig::N_Warp_Tile;
|
||||
constexpr int ItemsPerAccess = std::min(MaxVecSize, FlatmmConfig::K_Warp_Tile / KLane);
|
||||
|
||||
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
k_ / ItemsPerAccess,
|
||||
ItemsPerAccess});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 1, 3});
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
// gemm1
|
||||
// operand-A = [num_token, d_model]
|
||||
// operand-B = [num_expert, hidden, d_model]
|
||||
// operand-C = [num_token, topk, hidden]
|
||||
|
||||
// gemm2
|
||||
// operand-A = [num_token, topk, hidden]
|
||||
// operand-B = [num_expert, d_model, hidden]
|
||||
// operand-C = [num_token, d_model]
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename ScaleM,
|
||||
typename ScaleN>
|
||||
float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
false, // UsePersistentKernel_
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>; // Preshuffle_
|
||||
|
||||
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
static_assert(
|
||||
FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
|
||||
"requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
|
||||
}
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up
|
||||
? 2
|
||||
: 1; // determined by scale shuffle pattern
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using Kernel = ck_tile::
|
||||
MoeFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue, moe_kind>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK
|
||||
: args.NumTokens,
|
||||
args.K,
|
||||
args.stride_A,
|
||||
is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
const int outputN =
|
||||
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N;
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), s.stream_id_));
|
||||
else if(args.k_batch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(args.e_ptr,
|
||||
0,
|
||||
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
|
||||
float ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_moe_flatmm_example.inc"
|
||||
|
||||
template <template <typename PreType> typename FlatmmConfig>
|
||||
int run_moe_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
const std::string prec_type = arg_parser.get_str("prec");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
const std::string gemm_kind = arg_parser.get_str("gemm_kind");
|
||||
if(gemm_kind == "gemm1_gate_up")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "fp16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm1_gate_only")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "fp16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm2")
|
||||
{
|
||||
if(prec_type == "fp8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::fp8_t,
|
||||
FlatmmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf8")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::bf8_t,
|
||||
FlatmmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "bf16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
|
||||
FlatmmConfig<ck_tile::bfloat16_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(prec_type == "fp16")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts<ck_tile::half_t,
|
||||
FlatmmConfig<ck_tile::half_t>,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
|
||||
"[gemm1_gate_only | gemm1_gate_up | gemm2]");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_moe_flatmm_example<FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
return !run_moe_flatmm_example<FlatmmConfig32>(argc, argv);
|
||||
}
|
||||
else if(warp_tile == 2)
|
||||
{
|
||||
return !run_moe_flatmm_example<FlatmmConfig16_950>(argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return !run_moe_flatmm_example<FlatmmConfig32_950>(argc, argv);
|
||||
}
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
170
example/ck_tile/18_flatmm/moe_flatmm.hpp
Normal file
170
example/ck_tile/18_flatmm/moe_flatmm.hpp
Normal file
@@ -0,0 +1,170 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/moe_flatmm.hpp"
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig32
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 32;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig32_950 : public FlatmmConfig32<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64;
|
||||
};
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 64;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr int N_Repeat =
|
||||
N_Tile / FlatmmConfig16<DataType>::N_Warp_Tile / FlatmmConfig16<DataType>::N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false; // N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename ADataType>
|
||||
struct GemmBasicTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
// ToDo: Add more bias config to support different categories of GEMM.
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::bf16_t;
|
||||
};
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
// ToDo: Add more bias config to support different categories of GEMM.
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmBasicTypeConfig<ck_tile::bf8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct is_8bit_type
|
||||
: std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>
|
||||
{
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("experts", "8", "Num of experts - 8 by default")
|
||||
.insert("NumTokens", "128", "M dimensions - 128 by default.")
|
||||
.insert("TopK", "3", "Top K - 3 by default.")
|
||||
.insert("N", "4096", "N dimensions - 4096 by default.")
|
||||
.insert("K", "4096", "K dimensions - 4096 by default.")
|
||||
.insert("stride_A", "", "Tensor A strides - it is empty by default.")
|
||||
.insert("stride_B", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_C", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("gemm_kind",
|
||||
"gemm1_gate_only",
|
||||
"Gemm kind in FFN network [gemm1_gate_only | gemm1_gate_up | gemm2] - "
|
||||
"gemm1_gate_only by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert(
|
||||
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
278
example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
Normal file
278
example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp
Normal file
@@ -0,0 +1,278 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mx_flatmm.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename MXFlatmmArchTraits,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ScaleA,
|
||||
typename ScaleB,
|
||||
bool UsePersistentKernel = false,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
|
||||
ck_tile::DeviceMem& b_shuffle_dev_buf,
|
||||
ck_tile::DeviceMem& c_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ScaleA scale_a,
|
||||
ScaleB scale_b,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
using FlatmmConfig = typename MXFlatmmArchTraits::Config;
|
||||
|
||||
ck_tile::ScaleFlatmmHostArgs<ScaleA, ScaleB> args = {a_dev_buf.GetDeviceBuffer(),
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_dev_buf.GetDeviceBuffer(),
|
||||
1,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C,
|
||||
scale_a,
|
||||
scale_b};
|
||||
|
||||
using FlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, FlatmmShape, Traits>;
|
||||
|
||||
using BaseFlatmmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * k_grain;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split);
|
||||
const bool has_hot_loop = BaseFlatmmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseFlatmmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time = BaseFlatmmPipeline::template TailHandler<true>(
|
||||
[&](auto has_hot_loop_, auto tail_num_) {
|
||||
constexpr auto has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_num_v = tail_num_.value;
|
||||
return mx_flatmm_calc<MXFlatmmArchTraits,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ScaleA,
|
||||
ScaleB,
|
||||
UsePersistentKernel,
|
||||
CDEElementWise,
|
||||
false,
|
||||
has_hot_loop_v,
|
||||
tail_num_v>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
},
|
||||
has_hot_loop,
|
||||
tail_num);
|
||||
|
||||
constexpr int APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr int BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / 32;
|
||||
std::size_t num_byte = sizeof(ADataType) * M * K / APackedSize +
|
||||
sizeof(BDataType) * N * K / BPackedSize + sizeof(CDataType) * M * N +
|
||||
sizeof(ck_tile::e8m0_t) * M * K / 32 +
|
||||
sizeof(ck_tile::e8m0_t) * N * K / 32;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run " << ck_tile::gemm_prec_str<ADataType, BDataType>() << " Flatmm kernel " //
|
||||
<< " M = " << M << " N = " << N << " K = " << K << " StrideA = " << stride_A
|
||||
<< " StrideB = " << stride_B << " StrideC = " << stride_C << " : " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "32", "m dimension")
|
||||
.insert("n", "512", "n dimension")
|
||||
.insert("k", "256", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("mx_prec",
|
||||
"fp4xfp4",
|
||||
"data type for activation and weight, support: fp4xfp4, fp6xfp6, fp8xfp8, fp8xfp4 "
|
||||
"and fp4xfp8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert("persistent", "0", "0: no persistent, 1: persistent kernel")
|
||||
.insert("warp_tile", "0", "0: 16x16x128 on gfx950.");
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
#include "run_mx_flatmm.inc"
|
||||
|
||||
int run_mx_flatmm_example(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
std::string mx_prec = arg_parser.get_str("mx_prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
int persistent_opt = arg_parser.get_int("persistent");
|
||||
|
||||
std::cout << "Using default warptile of 16x16x128." << std::endl;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
if(mx_prec == "fp4" || mx_prec == "fp4xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
MXFlatmm_GFX950_FP4FP4_Traits,
|
||||
false>(arg_parser, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only non-persistent kernels are supported currently!");
|
||||
}
|
||||
else if(mx_prec == "fp6" || mx_prec == "fp6xfp6")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::pk_fp6x16_t,
|
||||
ck_tile::pk_fp6x16_t,
|
||||
ck_tile::fp16_t,
|
||||
MXFlatmm_GFX950_FP6FP6_Traits,
|
||||
false>(arg_parser, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
else if(mx_prec == "fp8" || mx_prec == "fp8xfp8")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp16_t,
|
||||
MXFlatmm_GFX950_FP8FP8_Traits,
|
||||
false>(arg_parser, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
else if(mx_prec == "fp8xfp4")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::fp8_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::fp16_t,
|
||||
MXFlatmm_GFX950_FP8FP4_Traits,
|
||||
false>(arg_parser, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
else if(mx_prec == "fp4xfp8")
|
||||
{
|
||||
if(persistent_opt == 0)
|
||||
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp16_t,
|
||||
MXFlatmm_GFX950_FP4FP8_Traits,
|
||||
false>(arg_parser, Row{}, Col{}, Row{});
|
||||
else
|
||||
throw std::runtime_error("Only support non-persistent kernel now!");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data_type!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return run_mx_flatmm_example(arg_parser);
|
||||
}
|
||||
else if(warp_tile == 1)
|
||||
{
|
||||
throw std::runtime_error("Only support MFMA_16x16x128 now!");
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported warp_tile!");
|
||||
}
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
34
example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp
Normal file
34
example/ck_tile/18_flatmm/mxgemm/mx_flatmm.hpp
Normal file
@@ -0,0 +1,34 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include "mx_flatmm_arch_traits.hpp"
|
||||
|
||||
template <typename MXFlatmmArchTraits,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise,
|
||||
bool Splitk,
|
||||
bool HasHotLoop,
|
||||
ck_tile::TailNumber TailNum>
|
||||
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s);
|
||||
178
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp
Normal file
178
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp
Normal file
@@ -0,0 +1,178 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace core {
|
||||
namespace arch {
|
||||
|
||||
// Use the amdgcn_target_id enum from arch.hpp
|
||||
using TargetId = amdgcn_target_id;
|
||||
|
||||
} // namespace arch
|
||||
} // namespace core
|
||||
} // namespace ck_tile
|
||||
|
||||
// Base FlatmmConfig with 16x16 warp tile (for non-GFX1250)
|
||||
struct MXFlatmmConfigBase16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct MXfp4_FlatmmConfig16 : public MXFlatmmConfigBase16
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Tile = 512;
|
||||
};
|
||||
|
||||
// Architecture traits for MX Flatmm - Primary template (gfx950 implementation)
|
||||
template <ck_tile::core::arch::TargetId Arch, typename FlatmmConfig>
|
||||
struct MXFlatmmArchTraits
|
||||
{
|
||||
static constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using Config = FlatmmConfig;
|
||||
|
||||
template <typename MXPipelineProblem>
|
||||
using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
|
||||
|
||||
static constexpr int GetNLane() { return Config::N_Warp_Tile; }
|
||||
|
||||
template <typename dtype>
|
||||
static auto preShuffleWeight(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
constexpr ck_tile::index_t NLane = Config::N_Warp_Tile;
|
||||
auto src_lengths = src.get_lengths();
|
||||
const int K = src_lengths[0];
|
||||
const int N = src_lengths[1];
|
||||
constexpr int packed_size = ck_tile::numeric_traits<dtype>::PackedSize;
|
||||
int KPack = std::is_same_v<dtype, ck_tile::pk_fp6x16_t>
|
||||
? 32
|
||||
: 16 * packed_size; // fp4/fp6:32 or fp8:16
|
||||
|
||||
int KLane = ck_tile::get_warp_size() / NLane;
|
||||
int K0 = K / (KLane * KPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({N * K}, {1}));
|
||||
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; k += packed_size)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
int tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
shuffled(outputIndex) = src(k, n);
|
||||
}
|
||||
}
|
||||
return shuffled;
|
||||
}
|
||||
|
||||
template <bool KLast, typename dtype>
|
||||
static auto preShuffleScale(ck_tile::HostTensor<dtype>& src)
|
||||
{
|
||||
auto src_lengths = src.get_lengths();
|
||||
const auto MN = KLast ? src_lengths[0] : src_lengths[1];
|
||||
const auto K = KLast ? src_lengths[1] : src_lengths[0];
|
||||
|
||||
size_t MNXdlPack = 2;
|
||||
size_t KXdlPack = 2;
|
||||
size_t XdlMNThread = Config::N_Warp_Tile; // 16
|
||||
size_t XdlKThread = ck_tile::get_warp_size() / XdlMNThread;
|
||||
|
||||
const auto MN_Paded = ck_tile::integer_least_multiple(MN, XdlMNThread * MNXdlPack);
|
||||
|
||||
ck_tile::HostTensor<dtype> shuffled(ck_tile::HostTensorDescriptor({MN_Paded * K}, {1}));
|
||||
|
||||
size_t K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(size_t n = 0; n < MN_Paded; ++n)
|
||||
{
|
||||
for(size_t k = 0; k < K; ++k)
|
||||
{
|
||||
auto n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
auto tempn = n % (XdlMNThread * MNXdlPack);
|
||||
auto n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
auto n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
auto k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
auto tempk = k % (XdlKThread * KXdlPack);
|
||||
auto k1 = tempk % XdlKThread; // i XdlKThread
|
||||
auto k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
auto outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread +
|
||||
n1 * MNXdlPack * KXdlPack + k2 * MNXdlPack + n2;
|
||||
|
||||
if constexpr(KLast)
|
||||
shuffled(outputIndex) = n < MN ? src(n, k) : dtype{};
|
||||
else
|
||||
shuffled(outputIndex) = n < MN ? src(k, n) : dtype{};
|
||||
}
|
||||
}
|
||||
return shuffled;
|
||||
}
|
||||
};
|
||||
|
||||
using MXFlatmm_GFX950_FP4FP4_Traits =
|
||||
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXfp4_FlatmmConfig16>;
|
||||
using MXFlatmm_GFX950_FP8FP8_Traits =
|
||||
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;
|
||||
using MXFlatmm_GFX950_FP6FP6_Traits =
|
||||
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;
|
||||
using MXFlatmm_GFX950_FP8FP4_Traits =
|
||||
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;
|
||||
using MXFlatmm_GFX950_FP4FP8_Traits =
|
||||
MXFlatmmArchTraits<ck_tile::core::arch::TargetId::GFX950, MXFlatmmConfigBase16>;
|
||||
42
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake
Normal file
42
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cmake
Normal file
@@ -0,0 +1,42 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
function(mx_flatmm_instance_generate FILE_LIST)
|
||||
set(C_DATA_TYPE FP16)
|
||||
set(A_LAYOUT ROW)
|
||||
set(B_LAYOUT COL)
|
||||
set(C_LAYOUT ROW)
|
||||
|
||||
set(MXFLATMM_ARCH)
|
||||
|
||||
if (GPU_TARGETS MATCHES "gfx95")
|
||||
list(APPEND MXFLATMM_ARCH MXFlatmm_GFX950_)
|
||||
endif()
|
||||
|
||||
# foreach(PERSISTENT false true)
|
||||
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
|
||||
foreach(PERSISTENT false)
|
||||
foreach(DATA_TYPE FP4xFP4 FP8xFP8 FP6xFP6 FP8xFP4 FP4xFP8)
|
||||
string(REPLACE "x" ";" DATA_TYPE_AB ${DATA_TYPE})
|
||||
list(GET DATA_TYPE_AB 0 A_DATA_TYPE)
|
||||
list(GET DATA_TYPE_AB 1 B_DATA_TYPE)
|
||||
foreach(ARCH ${MXFLATMM_ARCH})
|
||||
set(MXFLATMM_ARCH_TRAITS "${ARCH}${A_DATA_TYPE}${B_DATA_TYPE}_Traits")
|
||||
foreach(SPLIT_K false true)
|
||||
foreach(HAS_HOT_LOOP false true)
|
||||
foreach(TAIL_NUMBER ODD EVEN)
|
||||
set(KERNEL_FILE mxgemm/instance_${ARCH}${DATA_TYPE}_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
|
||||
string(TOLOWER ${KERNEL_FILE} KERNEL_FILE)
|
||||
configure_file(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
|
||||
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
|
||||
@ONLY)
|
||||
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endforeach()
|
||||
set(${FILE_LIST} ${${FILE_LIST}} PARENT_SCOPE)
|
||||
endfunction()
|
||||
57
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in
Normal file
57
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.cpp.in
Normal file
@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "mx_flatmm_instance.hpp"
|
||||
|
||||
// clang-format off
|
||||
#define MXFLATMM_ARCH_TRAITS @MXFLATMM_ARCH_TRAITS@
|
||||
#define A_DATA_TYPE @A_DATA_TYPE@
|
||||
#define B_DATA_TYPE @B_DATA_TYPE@
|
||||
#define C_DATA_TYPE @C_DATA_TYPE@
|
||||
#define A_LAYOUT @A_LAYOUT@
|
||||
#define B_LAYOUT @B_LAYOUT@
|
||||
#define C_LAYOUT @C_LAYOUT@
|
||||
#define PERSISTENT @PERSISTENT@
|
||||
#define SPLIT_K @SPLIT_K@
|
||||
#define HAS_HOT_LOOP @HAS_HOT_LOOP@
|
||||
#define TAIL_NUMBER @TAIL_NUMBER@
|
||||
// clang-format on
|
||||
|
||||
using FP4 = ck_tile::pk_fp4_t;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using FP6 = ck_tile::pk_fp6x16_t;
|
||||
using FP16 = ck_tile::fp16_t;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
|
||||
using ROW = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using COL = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
inline constexpr auto ODD = ck_tile::TailNumber::Odd;
|
||||
inline constexpr auto EVEN = ck_tile::TailNumber::Even;
|
||||
|
||||
inline constexpr int ScaleGranularityM = 1;
|
||||
inline constexpr int ScaleGranularityN = 1;
|
||||
inline constexpr int ScaleGranularityK = 32;
|
||||
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>;
|
||||
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>;
|
||||
|
||||
template float mx_flatmm_calc<MXFLATMM_ARCH_TRAITS,
|
||||
A_DATA_TYPE,
|
||||
B_DATA_TYPE,
|
||||
/*DsDatatype*/ ck_tile::tuple<>,
|
||||
/*AccDataType*/ float,
|
||||
C_DATA_TYPE,
|
||||
A_LAYOUT,
|
||||
B_LAYOUT,
|
||||
/*DsLayout*/ ck_tile::tuple<>,
|
||||
C_LAYOUT,
|
||||
ScaleM,
|
||||
ScaleN,
|
||||
PERSISTENT,
|
||||
/*CDEElementWise*/ ck_tile::element_wise::PassThrough,
|
||||
SPLIT_K,
|
||||
HAS_HOT_LOOP,
|
||||
TAIL_NUMBER>(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s);
|
||||
174
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp
Normal file
174
example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp
Normal file
@@ -0,0 +1,174 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "mx_flatmm.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
using is_row_major_t = ck_tile::bool_constant<
|
||||
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>;
|
||||
|
||||
template <typename MXFlatmmArchTraits,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
bool persistent,
|
||||
typename CDEElementWise,
|
||||
bool Splitk,
|
||||
bool HasHotLoop,
|
||||
ck_tile::TailNumber TailNum>
|
||||
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using FlatmmConfig = typename MXFlatmmArchTraits::Config;
|
||||
|
||||
using FlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using MXGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
persistent,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>;
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
|
||||
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
|
||||
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
ck_tile::ignore = Splitk;
|
||||
|
||||
// determined by scale shuffle pattern
|
||||
constexpr int BlockedXDLN_PerWarp = MXFlatmmArchTraits::BlockedXDLN_PerWarp;
|
||||
|
||||
using MXPipelineProblem = ck_tile::MXFlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
FlatmmShape,
|
||||
MXGemmTraits,
|
||||
scheduler,
|
||||
HasHotLoop,
|
||||
TailNum>;
|
||||
|
||||
using MXFlatmmPipeline =
|
||||
typename MXFlatmmArchTraits::template MXFlatmmPipeline<MXPipelineProblem>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
using GemmEpilogue =
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using Kernel = ck_tile::MXFlatmmKernel<TilePartitioner, MXFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << FlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << FlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << MXPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << MXFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major_t<ALayout>{}));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major_t<BLayout>{}));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
182
example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
Normal file
182
example/ck_tile/18_flatmm/mxgemm/run_mx_flatmm.inc
Normal file
@@ -0,0 +1,182 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
template <typename PrecActType,
|
||||
typename PrecWeightType,
|
||||
typename CDataType,
|
||||
typename MXFlatmmArchTraits,
|
||||
bool UsePersistentKernel = false,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_mx_flatmm_with_layouts(const ck_tile::ArgParser& arg_parser,
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const CLayout c_layout = CLayout{})
|
||||
{
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using AccDataType = float;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
constexpr int ScaleGranularityM = 1;
|
||||
constexpr int ScaleGranularityN = 1;
|
||||
constexpr int ScaleGranularityK = 32;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
ck_tile::index_t n_warmup = arg_parser.get_int("warmup");
|
||||
ck_tile::index_t n_repeat = arg_parser.get_int("repeat");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
|
||||
|
||||
auto scale_stride_A = ck_tile::get_default_stride(
|
||||
M / ScaleGranularityM, K / ScaleGranularityK, 0, is_row_major(a_layout));
|
||||
auto scale_stride_B = ck_tile::get_default_stride(
|
||||
K / ScaleGranularityK, N / ScaleGranularityN, 0, is_row_major(b_layout));
|
||||
|
||||
if(K % ScaleGranularityK != 0)
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleGranularityK.");
|
||||
if(K % ck_tile::numeric_traits<ADataType>::PackedSize != 0 ||
|
||||
K % ck_tile::numeric_traits<BDataType>::PackedSize != 0)
|
||||
throw std::runtime_error("wrong! K must be multiple of packed size.");
|
||||
|
||||
ck_tile ::HostTensor<ADataType> a_host(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_origin_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_rslt_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_a(ck_tile::host_tensor_descriptor(
|
||||
M / ScaleGranularityM, K / ScaleGranularityK, scale_stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::host_tensor_descriptor(
|
||||
K / ScaleGranularityK, N / ScaleGranularityN, scale_stride_B, is_row_major(b_layout)));
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_fp6x16_t>)
|
||||
{
|
||||
auto a_buffer_bytes = a_host.get_element_space_size_in_bytes();
|
||||
auto b_buffer_bytes = b_origin_host.get_element_space_size_in_bytes();
|
||||
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{-1.f, 1.f}(scale_b);
|
||||
std::vector<int8_t> random_bufA(a_buffer_bytes);
|
||||
std::vector<int8_t> random_bufB(b_buffer_bytes);
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<int> dis(1, 4);
|
||||
|
||||
for(size_t i = 0; i < a_buffer_bytes; ++i)
|
||||
random_bufA[i] = static_cast<int8_t>(dis(gen));
|
||||
|
||||
for(size_t i = 0; i < b_buffer_bytes; ++i)
|
||||
random_bufB[i] = static_cast<int8_t>(dis(gen));
|
||||
|
||||
memcpy(a_host.data(), random_bufA.data(), a_buffer_bytes);
|
||||
memcpy(b_origin_host.data(), random_bufB.data(), b_buffer_bytes);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{-2.f, 2.f}(scale_b);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<>{1.f, 1.f}(scale_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! Unexpected init_method");
|
||||
}
|
||||
}
|
||||
|
||||
const auto b_shuffled_host = MXFlatmmArchTraits::preShuffleWeight(b_origin_host);
|
||||
const auto scale_a_shuffled = MXFlatmmArchTraits::template preShuffleScale<true>(scale_a);
|
||||
const auto scale_b_shuffled = MXFlatmmArchTraits::template preShuffleScale<false>(scale_b);
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_shuffled_dev_buf(b_shuffled_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a_shuffled.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b_shuffled.get_element_space_size_in_bytes());
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
b_shuffled_dev_buf.ToDevice(b_shuffled_host.data());
|
||||
c_rslt_host.SetZero();
|
||||
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
|
||||
|
||||
auto scale_a_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
|
||||
auto scale_b_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
|
||||
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
|
||||
invoke_mx_flatmm<MXFlatmmArchTraits,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
decltype(scale_a_dev_ptr),
|
||||
decltype(scale_b_dev_ptr),
|
||||
UsePersistentKernel>(a_dev_buf,
|
||||
b_shuffled_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
scale_a_dev_ptr,
|
||||
scale_b_dev_ptr,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
|
||||
bool pass = true;
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_mx_gemm<ADataType, BDataType, ScaleType, AccDataType, CDataType>(
|
||||
a_host, b_origin_host, c_m_n_host_ref, scale_a, scale_b);
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_rslt_host, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass ? 0 : -1;
|
||||
}
|
||||
261
example/ck_tile/18_flatmm/run_flatmm_example.inc
Normal file
261
example/ck_tile/18_flatmm/run_flatmm_example.inc
Normal file
@@ -0,0 +1,261 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
template <typename PrecType,
|
||||
typename FlatmmConfig,
|
||||
int ScaleGranularityM = -1,
|
||||
int ScaleGranularityN = -1,
|
||||
bool UsePersistentKernel = false,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_flatmm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
// persistent not added
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_host(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_origin_host(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_rslt_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));
|
||||
|
||||
// TODO: add different init types
|
||||
if(init_method == 0)
|
||||
{
|
||||
// ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
// ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(per_channel_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_host.SetZero();
|
||||
b_origin_host.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem per_channel_scale_dev_buf(
|
||||
per_channel_scale.get_element_space_size_in_bytes());
|
||||
|
||||
a_dev_buf.ToDevice(a_host.data());
|
||||
c_rslt_host.SetZero();
|
||||
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
|
||||
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
|
||||
|
||||
// do pre-shuffle
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
|
||||
if constexpr(FlatmmConfig::TiledMMAPermuteN)
|
||||
{
|
||||
return shuffle_b_v1<FlatmmConfig>(b_origin_host);
|
||||
}
|
||||
else
|
||||
{
|
||||
return shuffle_b_v0<FlatmmConfig>(b_origin_host);
|
||||
}
|
||||
}();
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
|
||||
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
|
||||
|
||||
invoke_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
decltype(per_token_scale_dev_ptr),
|
||||
decltype(per_channel_scale_dev_ptr),
|
||||
UsePersistentKernel>(a_dev_buf,
|
||||
b_shuffle_dev_buf,
|
||||
c_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
per_token_scale_dev_ptr,
|
||||
per_channel_scale_dev_ptr,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_dev_buf.FromDevice(c_rslt_host.data());
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
if(ScaleGranularityM != -1 || ScaleGranularityN != -1)
|
||||
throw std::runtime_error("ScaleAB is not supported for CPU verification!\n");
|
||||
ck_tile::HostTensor<CDataType> c_ref_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_ref_host.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_host, b_origin_host, c_ref_host);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_ref_host.mData.begin(), c_ref_host.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_rslt_host,
|
||||
c_ref_host,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU veification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
ck_tile::DeviceMem b_origin_dev_buf(b_origin_host.get_element_space_size_in_bytes());
|
||||
b_origin_dev_buf.ToDevice(b_origin_host.data());
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::DeviceMem c_gpu_ref_dev_buf(c_gpu_ref_host.get_element_space_size_in_bytes());
|
||||
c_gpu_ref_host.SetZero();
|
||||
c_gpu_ref_dev_buf.SetZero();
|
||||
|
||||
ADataType* d_A;
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
|
||||
ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(
|
||||
d_A, a_dev_buf.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice));
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_origin_dev_buf.GetDeviceBuffer(),
|
||||
N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
|
||||
{
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_blockwise_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
d_A,
|
||||
d_B,
|
||||
d_C,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
ScaleGranularityM,
|
||||
ScaleGranularityN,
|
||||
K,
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_dev_buf.GetDeviceBuffer(),
|
||||
d_C,
|
||||
M * N * sizeof(CDataType),
|
||||
hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_A));
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
c_gpu_ref_dev_buf.FromDevice(c_gpu_ref_host.data());
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_gpu_ref_host.mData.begin(), c_gpu_ref_host.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_rslt_host,
|
||||
c_gpu_ref_host,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
588
example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc
Normal file
588
example/ck_tile/18_flatmm/run_grouped_flatmm_example.inc
Normal file
@@ -0,0 +1,588 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
const ck_tile::ContiguousGroupedFlatmmHostArgs<ScaleM, ScaleN>& args)
|
||||
{
|
||||
float ave_time = grouped_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
false,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
|
||||
sizeof(BDataType) * args.N * args.K +
|
||||
sizeof(CDataType) * args.M * args.N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ScaleM,
|
||||
typename ScaleN,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int val_m,
|
||||
const ck_tile::MaskedGroupedFlatmmHostArgs<ScaleM, ScaleN>& args)
|
||||
{
|
||||
float ave_time = grouped_flatmm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
false,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
std::size_t flop = std::size_t(2) * val_m * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * val_m * args.K +
|
||||
sizeof(BDataType) * args.N * args.K * args.group_count +
|
||||
sizeof(CDataType) * val_m * args.N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename PrecType,
|
||||
typename FlatmmConfig,
|
||||
int ScaleGranularityM = -1,
|
||||
int ScaleGranularityN = -1,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_contiguous_grouped_flatmm_example_with_layouts(
|
||||
int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
constexpr int BlockM = FlatmmConfig::M_Tile;
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
|
||||
if(!(int(Ms.size()) == group_count))
|
||||
{
|
||||
std::cout << "Please check the input data." << std::endl;
|
||||
// padding additional Ms if needed
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 64 * i);
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::index_t M =
|
||||
std::accumulate(Ms.begin(), Ms.begin() + group_count, 0, [](auto acc, auto group_m) {
|
||||
// round up to the multiple of BlockM
|
||||
return acc + (group_m + BlockM - 1) / BlockM * BlockM;
|
||||
});
|
||||
std::cout << "Total M: " << M << std::endl;
|
||||
ck_tile::index_t N = Ns[0];
|
||||
ck_tile::index_t K = Ks[0];
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
ck_tile::index_t stride_A = 0;
|
||||
ck_tile::index_t stride_B = 0;
|
||||
ck_tile::index_t stride_C = 0;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N * group_count, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k_tensor(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n_tensor(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N * group_count, stride_B, is_row_major(b_layout))));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(c_layout))));
|
||||
|
||||
ck_tile::HostTensor<AccDataType> per_token_scale(ck_tile::HostTensorDescriptor({M}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> per_channel_scale(ck_tile::HostTensorDescriptor({N}, {1}));
|
||||
|
||||
std::vector<ck_tile::index_t> m_indices(M);
|
||||
int indices_fill_start = 0;
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
int group_m = Ms[i];
|
||||
int padded_group_m = (group_m + BlockM - 1) / BlockM * BlockM;
|
||||
for(int j = 0; j < padded_group_m; j++)
|
||||
{
|
||||
m_indices[indices_fill_start + j] = j < group_m ? i : -1; // -1 for padding
|
||||
}
|
||||
indices_fill_start += padded_group_m;
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
|
||||
|
||||
assert(N % N_Warp_Tile == 0 &&
|
||||
"N must be divisible by N_Warp_Tile for contiguous grouped gemm");
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host =
|
||||
ck_tile::shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
|
||||
std::unique_ptr<ck_tile::DeviceMem> b_shfl_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes()));
|
||||
|
||||
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem per_channel_scale_dev_buf(
|
||||
per_channel_scale.get_element_space_size_in_bytes());
|
||||
|
||||
c_m_n_dev_buf->SetZero();
|
||||
|
||||
ck_tile::DeviceMem m_indices_dev_buf(M * sizeof(ck_tile::index_t));
|
||||
m_indices_dev_buf.ToDevice(m_indices.data());
|
||||
|
||||
a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
|
||||
b_shfl_dev_buf->ToDevice(b_shuffle_host.data());
|
||||
|
||||
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
|
||||
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
|
||||
|
||||
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
|
||||
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
|
||||
|
||||
ck_tile::ContiguousGroupedFlatmmHostArgs<decltype(per_token_scale_dev_ptr),
|
||||
decltype(per_channel_scale_dev_ptr)>
|
||||
kernal_args{static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_m_k_dev_buf->GetDeviceBuffer(),
|
||||
stride_A,
|
||||
b_shfl_dev_buf->GetDeviceBuffer(),
|
||||
stride_B,
|
||||
{},
|
||||
{},
|
||||
c_m_n_dev_buf->GetDeviceBuffer(),
|
||||
stride_C,
|
||||
kbatch,
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
|
||||
|
||||
invoke_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
decltype(per_token_scale_dev_ptr),
|
||||
decltype(per_channel_scale_dev_ptr)>(warmup, repeat, kernal_args);
|
||||
c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Not support v=1 host verification in contiguous grouped gemm, use "
|
||||
"v=2 device verification instead");
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
|
||||
ck_tile::hip_check_error(hipMemset(d_C, 0, M * N * sizeof(CDataType)));
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::index_t acc_m = 0;
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::index_t padded_M = (Ms[i] + BlockM - 1) / BlockM * BlockM;
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_tensor.data() + i * N * K,
|
||||
N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice));
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + acc_m * K,
|
||||
d_B,
|
||||
d_C + acc_m * N,
|
||||
padded_M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
acc_m += padded_M;
|
||||
}
|
||||
ck_tile::hip_check_error(hipMemcpy(
|
||||
c_gpu_ref_host.data(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
float rtol = 1e-3;
|
||||
float atol = 1e-3;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename PrecType,
|
||||
typename FlatmmConfig,
|
||||
int ScaleGranularityM = -1,
|
||||
int ScaleGranularityN = -1,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_masked_grouped_flatmm_example_with_layouts(
|
||||
int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
constexpr int BlockM = FlatmmConfig::M_Tile;
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
|
||||
if(!(int(Ms.size()) == group_count))
|
||||
{
|
||||
std::cout << "Please check the input data." << std::endl;
|
||||
// padding additional Ms if needed
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 64 * i);
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::index_t M = 4096; // Ms[0];
|
||||
ck_tile::index_t N = Ns[0];
|
||||
ck_tile::index_t K = Ks[0];
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
ck_tile::index_t stride_A = K;
|
||||
ck_tile::index_t stride_B = K;
|
||||
ck_tile::index_t stride_C = N;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(group_count * M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N * group_count, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(group_count * M, N, stride_C, is_row_major(c_layout));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k_tensor(
|
||||
ck_tile::host_tensor_descriptor(group_count * M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n_tensor(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N * group_count, stride_B, is_row_major(b_layout))));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_tensor(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(c_layout))));
|
||||
|
||||
ck_tile::HostTensor<AccDataType> per_token_scale(
|
||||
ck_tile::HostTensorDescriptor({group_count * M}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> per_channel_scale(
|
||||
ck_tile::HostTensorDescriptor({group_count * N}, {1}));
|
||||
|
||||
std::vector<ck_tile::index_t> m_indices(group_count);
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
int group_m = Ms[i];
|
||||
int padded_group_m = (group_m + BlockM - 1) / BlockM * BlockM;
|
||||
for(int j = 0; j < padded_group_m; j++)
|
||||
{
|
||||
m_indices[i] = group_m;
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{-1.f, 1.f}(per_channel_scale);
|
||||
|
||||
assert(N % N_Warp_Tile == 0 &&
|
||||
"N must be divisible by N_Warp_Tile for contiguous grouped gemm");
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host =
|
||||
ck_tile::shuffle_b<FlatmmConfig, BDataType>(b_k_n_tensor);
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> a_m_k_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(a_m_k_tensor.get_element_space_size_in_bytes()));
|
||||
std::unique_ptr<ck_tile::DeviceMem> b_shfl_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(b_shuffle_host.get_element_space_size_in_bytes()));
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_dev_buf(
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes()));
|
||||
|
||||
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem per_channel_scale_dev_buf(
|
||||
per_channel_scale.get_element_space_size_in_bytes());
|
||||
c_m_n_dev_buf->SetZero();
|
||||
|
||||
ck_tile::DeviceMem m_indices_dev_buf(group_count * sizeof(ck_tile::index_t));
|
||||
m_indices_dev_buf.ToDevice(m_indices.data());
|
||||
|
||||
a_m_k_dev_buf->ToDevice(a_m_k_tensor.data());
|
||||
b_shfl_dev_buf->ToDevice(b_shuffle_host.data());
|
||||
|
||||
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
|
||||
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
|
||||
|
||||
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
|
||||
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
|
||||
ck_tile::MaskedGroupedFlatmmHostArgs<decltype(per_token_scale_dev_ptr),
|
||||
decltype(per_channel_scale_dev_ptr)>
|
||||
kernal_args{static_cast<ck_tile::index_t*>(m_indices_dev_buf.GetDeviceBuffer()),
|
||||
group_count,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_m_k_dev_buf->GetDeviceBuffer(),
|
||||
stride_A,
|
||||
b_shfl_dev_buf->GetDeviceBuffer(),
|
||||
stride_B,
|
||||
{},
|
||||
{},
|
||||
c_m_n_dev_buf->GetDeviceBuffer(),
|
||||
stride_C,
|
||||
kbatch,
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
|
||||
int sum_val_m = 0;
|
||||
for(int gi = 0; gi < group_count; gi++)
|
||||
{
|
||||
sum_val_m += m_indices[gi];
|
||||
}
|
||||
|
||||
invoke_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
decltype(per_token_scale_dev_ptr),
|
||||
decltype(per_channel_scale_dev_ptr)>(warmup, repeat, sum_val_m, kernal_args);
|
||||
c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Not support v=1 host verification in contiguous grouped gemm, use "
|
||||
"v=2 device verification instead");
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
BDataType* d_B;
|
||||
CDataType* d_C;
|
||||
ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
|
||||
ck_tile::hip_check_error(hipMalloc(&d_C, group_count * M * N * sizeof(CDataType)));
|
||||
ck_tile::hip_check_error(hipMemset(d_C, 0, group_count * M * N * sizeof(CDataType)));
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_gpu_ref_host(
|
||||
ck_tile::host_tensor_descriptor(group_count * M, N, stride_C, is_row_major(CLayout{})));
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::hip_check_error(hipMemcpy(d_B,
|
||||
b_k_n_tensor.data() + i * N * K,
|
||||
N * K * sizeof(BDataType),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
if constexpr(ScaleGranularityM == -1 && ScaleGranularityN == -1)
|
||||
{
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + i * M * K,
|
||||
d_B,
|
||||
d_C + i * M * N,
|
||||
m_indices[i],
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_blockwise_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
static_cast<ADataType*>(a_m_k_dev_buf->GetDeviceBuffer()) + i * M * K,
|
||||
d_B,
|
||||
d_C + i * M * N,
|
||||
m_indices[i],
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
ScaleGranularityM,
|
||||
ScaleGranularityN,
|
||||
K,
|
||||
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()) + i * M,
|
||||
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())) +
|
||||
i* N;
|
||||
}
|
||||
ck_tile::hip_check_error(hipMemcpy(c_gpu_ref_host.data() + i * M * N,
|
||||
d_C + i * M * N,
|
||||
M * N * sizeof(CDataType),
|
||||
hipMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
ck_tile::hip_check_error(hipFree(d_B));
|
||||
ck_tile::hip_check_error(hipFree(d_C));
|
||||
|
||||
float rtol = 1e-3;
|
||||
float atol = 1e-3;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_gpu_ref_host, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
319
example/ck_tile/18_flatmm/run_moe_flatmm_example.inc
Normal file
319
example/ck_tile/18_flatmm/run_moe_flatmm_example.inc
Normal file
@@ -0,0 +1,319 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ck_tile::MoeFlatmmKind kind,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename MoeHostArgs>
|
||||
float invoke_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
|
||||
{
|
||||
float ave_time = moe_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
kind,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::string op_name{"Moe Gemm"};
|
||||
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
|
||||
sizeof(BDataType) * args.N * args.K +
|
||||
sizeof(CDataType) * args.M * args.N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// End-to-end MoE flat-GEMM example driver for a fixed (ALayout, BLayout,
// CLayout) combination: parses CLI arguments, builds synthetic host tensors
// and MoE routing metadata, uploads everything to the device, runs the
// kernel via invoke_moe_gemm, and optionally validates against a GPU
// reference implementation.
//
// PrecType selects the A/B/C/Acc data types via GemmBasicTypeConfig; kind
// selects which FFN GEMM of the MoE block is exercised.
// Returns the validation result as an int (1 = pass / validation skipped,
// 0 = mismatch); returns -1 on argument-parsing failure.
template <typename PrecType,
          typename FlatmmConfig,
          ck_tile::MoeFlatmmKind kind,
          typename ALayout,
          typename BLayout,
          typename CLayout>
int run_moe_gemm_example_with_layouts(int argc,
                                      char* argv[],
                                      const ALayout a_layout = ALayout{},
                                      const BLayout b_layout = BLayout{},
                                      [[maybe_unused]] const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = create_args(argc, argv);

    if(!result)
    {
        return -1;
    };

    using ADataType   = typename GemmBasicTypeConfig<PrecType>::ADataType;
    using BDataType   = typename GemmBasicTypeConfig<PrecType>::BDataType;
    using CDataType   = typename GemmBasicTypeConfig<PrecType>::CDataType;
    using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;

    // Scale granularity of 1 selects per-token / per-channel scaling
    // (one scale value per row / column respectively).
    constexpr int ScaleGranularityM = 1;
    constexpr int ScaleGranularityN = 1;

    const ck_tile::index_t N          = arg_parser.get_int("N");
    const ck_tile::index_t K          = arg_parser.get_int("K");
    ck_tile::index_t stride_A         = arg_parser.get_int("stride_A");
    ck_tile::index_t stride_B         = arg_parser.get_int("stride_B");
    ck_tile::index_t stride_C         = arg_parser.get_int("stride_C");
    const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
    const ck_tile::index_t topk       = arg_parser.get_int("TopK");
    const ck_tile::index_t warmup     = arg_parser.get_int("warmup");
    const ck_tile::index_t repeat     = arg_parser.get_int("repeat");
    const ck_tile::index_t experts    = arg_parser.get_int("experts");

    // TODO: replace the magic declaration
    const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;

    // NOTE(review): sorted_tile_num = ceil(num_tokens/MPerBlock)*MPerBlock*topk
    // is a padded *token* count times topk, yet it is used below both as a
    // tile count (expert_ids size) and multiplied by MPerBlock again for
    // sorted_size and M — confirm the intended units; a factor of MPerBlock
    // may be doubled here.
    ck_tile::index_t sorted_tile_num = (num_tokens + MPerBlock - 1) / MPerBlock * MPerBlock * topk;
    ck_tile::index_t valid_tile_num  = sorted_tile_num;
    ck_tile::index_t sorted_size     = sorted_tile_num * MPerBlock;

    const ck_tile::index_t M = sorted_tile_num * MPerBlock;
    // gemm1 with gate+up fusion writes half as many output columns as B has.
    const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;

    static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
    // gemm2 consumes the expanded (num_tokens * topk) activations; the input
    // GEMMs consume the raw tokens and expand on output.
    constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;

    stride_A = ck_tile::get_default_stride(
        IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
    stride_C = ck_tile::get_default_stride(
        IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));

    // Host tensors: A holds activations, B holds all experts' weights
    // concatenated along N, C holds the (possibly topk-expanded) output.
    auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
        IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
    auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
        is_row_major(b_layout)
            ? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
            : ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
    auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
        IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));

    ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
    ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);

    // Pre-shuffle B into the layout the flatmm kernel expects.
    auto b_shuffle_host = flatmm_shuffle_b<FlatmmConfig>(b_k_n_tensor);

    std::cout << "moe_flatmm:" //
              << "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
              << "\n topk: " << topk << "\n sorted_tile_num: " << sorted_tile_num
              << "\n a_m_k: " << a_m_k_tensor.mDesc << "\n b_k_n: " << b_k_n_tensor.mDesc
              << "\n b_shuffle: " << b_shuffle_host.mDesc << "\n c_m_n: " << c_m_n_tensor.mDesc
              << std::endl;

    // Device buffers; both the original and shuffled B are uploaded because
    // the kernel uses the shuffled copy and the reference uses the original.
    ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};

    a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
    b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
    b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
    c_m_n_dev_buf.SetZero();
    c_m_n_tensor.SetZero();

    const void* p_a         = a_m_k_dev_buf.GetDeviceBuffer();
    const void* p_b_origin  = b_origin_dev_buf.GetDeviceBuffer();
    const void* p_b_shuffle = b_shuffle_dev_buf.GetDeviceBuffer();
    void* p_c               = c_m_n_dev_buf.GetDeviceBuffer();

    // TODO: malloc and init sorted tokens and max tokens buffer

    // MoE routing metadata (synthetic, built below rather than by a real
    // top-k router): per-tile expert ids, sorted token ids, routing weights.
    ck_tile::HostTensor<ck_tile::index_t> expert_ids(
        ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
    ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
        ck_tile::HostTensorDescriptor({sorted_size}, {1}));
    ck_tile::HostTensor<AccDataType> expert_weight(
        ck_tile::HostTensorDescriptor({sorted_size}, {1}));
    ck_tile::HostTensor<ck_tile::index_t> max_token_id(
        ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));

    ck_tile::HostTensor<AccDataType> per_token_scale(
        ck_tile::HostTensorDescriptor({IsInputGemm ? num_tokens : M}, {1}));
    ck_tile::HostTensor<AccDataType> per_channel_scale(
        ck_tile::HostTensorDescriptor({N * experts}, {1}));

    ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_token_scale);
    ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_channel_scale);

    // for verification only, no need to satisfy weight normalization
    ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);

    ck_tile::DeviceMem sorted_token_ids_dev{sorted_token_ids.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem expert_ids_dev{expert_ids.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem max_token_id_dev{max_token_id.get_element_space_size_in_bytes()};
    ck_tile::DeviceMem expert_weight_dev{expert_weight.get_element_space_size_in_bytes()};

    ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
    ck_tile::DeviceMem per_channel_scale_dev_buf(
        per_channel_scale.get_element_space_size_in_bytes());

    // NOTE(review): element [0] is the valid-row count; the meaning of the
    // remaining hard-coded entries (0,1,2,...,8,8) is not evident from this
    // file — confirm against the kernel's max_token_id contract. This also
    // replaces mData wholesale, ignoring the {1 + sorted_tile_num} size above.
    max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
    // int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}

    // Distribute tiles across experts as evenly as possible (round-robin in
    // contiguous chunks of ceil(valid_tile_num / experts)).
    for(int i = 0; i < sorted_tile_num; i++)
    {
        expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
    }

    int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
    // int token_per_tile = num_tokens * topk / valid_tile_num;
    int tokenid = 0;
    // sorted_token_ids.mData[0] = 0;
    for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
    {
        int tile_off = i % MPerBlock;
        if(tile_off < token_per_tile && tokenid < num_tokens * topk)
        {
            // Pack the token index in the low 24 bits and its top-k slot
            // (tokenid / num_tokens) in the bits above.
            sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
            tokenid++;
        }
        else
        {
            // Padding slots use the out-of-range sentinel value num_tokens.
            sorted_token_ids.mData[i] = num_tokens;
        }
    }

    sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
    expert_ids_dev.ToDevice(expert_ids.data());
    max_token_id_dev.ToDevice(max_token_id.data());
    expert_weight_dev.ToDevice(expert_weight.data());
    per_token_scale_dev_buf.ToDevice(per_token_scale.data());
    per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());

    const ck_tile::index_t* p_sorted_token_ids_dev =
        static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
    const ck_tile::index_t* p_expert_ids_dev =
        static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
    const ck_tile::index_t* p_max_token_id_dev =
        static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
    const AccDataType* p_sorted_expert_weight_dev =
        static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());

    using MoeFlatmmArgs =
        ck_tile::MoeFlatmmHostArgs<ck_tile::FlatmmScalePointer<ScaleGranularityM>,
                                   ck_tile::FlatmmScalePointer<ScaleGranularityN>>;

    auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
        static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
    auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
        static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};

    MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
                            p_sorted_expert_weight_dev,
                            p_expert_ids_dev,
                            p_max_token_id_dev,
                            p_a,
                            p_b_shuffle,
                            p_c,
                            num_tokens,
                            experts,
                            topk,
                            1, // k_batch
                            M,
                            N,
                            K,
                            stride_A,
                            stride_B,
                            stride_C,
                            per_token_scale_dev_ptr,
                            per_channel_scale_dev_ptr};

    invoke_moe_gemm<FlatmmConfig,
                    ADataType,
                    BDataType,
                    ck_tile::tuple<>,
                    AccDataType,
                    CDataType,
                    ALayout,
                    BLayout,
                    ck_tile::tuple<>,
                    CLayout,
                    kind>(warmup, repeat, gemm_desc);

    c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());

    bool pass{true};
    if(arg_parser.get_int("validate"))
    {
        // GPU reference path: recompute with the un-shuffled B into a second
        // device buffer, then compare on the host.
        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
                                            outputN,
                                            stride_C,
                                            is_row_major(CLayout{})));

        c_m_n_host_ref.SetZero();

        std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
            std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());

        c_m_n_ref_buf->SetZero();

        // NOTE(review): the literal (1, 1, K) scale-granularity arguments are
        // assumed to mirror ScaleGranularityM/N above — confirm against
        // reference_moe_gemm_gpu's signature.
        ck_tile::reference_moe_gemm_gpu<ADataType,
                                        BDataType,
                                        AccDataType,
                                        CDataType,
                                        ALayout,
                                        BLayout,
                                        CLayout,
                                        static_cast<int>(kind),
                                        ck_tile::moe::MoeSilu>(
            p_sorted_token_ids_dev,
            p_expert_ids_dev,
            p_max_token_id_dev,
            static_cast<const ADataType*>(p_a),
            static_cast<const BDataType*>(p_b_origin),
            static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
            p_sorted_expert_weight_dev,
            num_tokens,
            MPerBlock,
            topk,
            M,
            N,
            K,
            stride_A,
            stride_B,
            stride_C,
            1,
            1,
            K,
            static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
            static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));

        c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());

        // Tighter tolerances for fp16 input GEMMs; looser otherwise.
        const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
        const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;

        pass = ck_tile::check_err(
            c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);

        std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
                  << std::endl;
        std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
    }

    return pass;
}
|
||||
38
example/ck_tile/18_flatmm/script/smoke_test_basic.sh
Executable file
38
example/ck_tile/18_flatmm/script/smoke_test_basic.sh
Executable file
@@ -0,0 +1,38 @@
|
||||
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
#
# Smoke test for tile_example_flatmm_basic: runs a small m/n/k sweep for
# each requested precision and reports per-case success/failure.

# Locate the built example binary anywhere under the current directory.
EXE="$(find . -name tile_example_flatmm_basic -type f | head -n 1)"
KNAME=1

export CK_WARMUP=0
export CK_REPEAT=1

# Intentionally expanded unquoted at the call site so it word-splits into
# separate arguments.
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'

# run_tests <precision>
# Sweeps m/n/k over {128,1024} x {128,2048} x {128,4096} and invokes the
# example once per combination with the given precision (e.g. "fp16").
# Failures are reported but do not abort the sweep.
run_tests() {
    # Fail early and loudly if the binary was not found (EXE empty).
    if [ -z "$EXE" ]; then
        echo "Error: tile_example_flatmm_basic binary not found."
        return 1
    fi
    for m in 128 1024; do
        for n in 128 2048; do
            for k in 128 4096; do

                # Test the command's exit status directly rather than via $?,
                # and quote $EXE so paths with spaces work.
                if "$EXE" -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS; then
                    echo "Success: Test with m=$m, n=$n, k=$k executed successfully."
                else
                    echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly."
                    # Optionally, exit or break if you need to halt further execution
                    # exit 1
                fi

            done
        done
    done
}

set -x

run_tests "bf16"
run_tests "fp16"

set +x
|
||||
Reference in New Issue
Block a user