This commit is contained in:
Ding, Yi
2026-03-11 23:03:20 -04:00
commit e6cd3f1e3f
6330 changed files with 1132789 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# Build the Stream-K GEMM example only when the GPU_TARGETS string contains "gfx9"
# (MATCHES is a regular-expression match, so any gfx9xx target enables it).
if(GPU_TARGETS MATCHES "gfx9")
add_executable(tile_example_streamk_gemm_basic streamk_gemm_basic.cpp)
# Per-example compile options; starts out empty and is extended below.
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
# Forward the CK_USE_OCP_FP8 switch to the example as the CK_TILE_USE_OCP_FP8 preprocessor definition.
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
# PRIVATE: the options apply to this executable only.
target_compile_options(tile_example_streamk_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
# Only shown at DEBUG log level, so a skipped target is silent in normal configure output.
message(DEBUG "Skipping ck_tile streamk gemm tests for current target")
endif()

View File

@@ -0,0 +1,37 @@
# Stream-K GEMM
This folder contains examples of Stream-K GEMMs using the ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
# Compile the Stream-K kernels
make tile_example_streamk_gemm_basic -j
```
This will result in an executable `build/bin/tile_example_streamk_gemm_basic`
## example
```
args:
-m m dimension (default:512)
-n n dimension (default:512)
-k k dimension (default:512)
-a_layout tensor A data layout (default: R)
-b_layout tensor B data layout (default: C)
-c_layout tensor C data layout (default: R)
-reduction_strategy strategy for storing results in C tensor. atomic/linear (default:atomic)
-persistent_dp persistent strategy for data-parallel section. Set to 0 for non-persistent or to 1 for persistent. (default:0)
-stride_a tensor A stride (default:0)
-stride_b tensor B stride (default:0)
-stride_c tensor C stride (default:0)
-v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmarking the kernel (default:50)
-repeat number of iterations to benchmark the kernel (default:100)
-timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu)
-init data initialization strategy. 0:random, 1:linear, 2:constant(1) (default:0)
-flush_cache flush the cache before running the kernel (default:true)
```

View File

@@ -0,0 +1,85 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
/// Compile-time defaults shared by every Stream-K GEMM example configuration.
/// Derived configurations add the tile shape and may shadow individual members.
struct GemmConfigurationBase
{
    // Padding switches for the M, N and K dimensions.
    static constexpr bool PAD_M = true;
    static constexpr bool PAD_N = true;
    static constexpr bool PAD_K = true;

    // Data-layout transformations; the example runner statically rejects
    // PERMUTE_A, PERMUTE_B, PRESHUFFLE and USE_STRUCTURED_SPARSITY when enabled.
    static constexpr bool PERMUTE_A               = false;
    static constexpr bool PERMUTE_B               = false;
    static constexpr bool TRANSPOSE_C             = false;
    static constexpr bool PRESHUFFLE              = false;
    static constexpr bool USE_STRUCTURED_SPARSITY = false;

    // Pipeline and launch settings.
    static constexpr auto SCHEDULER                    = ck_tile::GemmPipelineScheduler::Intrawave;
    static constexpr bool DOUBLE_SMEM_BUFFER           = false;
    static constexpr ck_tile::index_t NUM_WAVE_GROUPS  = 1;
    static constexpr int BLOCK_PER_CU                  = 1;
};
/// Tile-shape configuration used by the Stream-K example.
///
/// @tparam PrecisionType Input element type; only its size is inspected here, to pick
///                       the K extent of the warp tile.
/// @tparam IsPersistent  Exposed as PERSISTENT for the tile partitioner and GEMM traits.
///
/// NOTE(review): despite the "Interwave" in the name, SCHEDULER below is Intrawave
/// (the same value the base already provides) — confirm whether that is intended.
template <typename PrecisionType, bool IsPersistent>
struct GemmConfigurationMemoryInterwave : public GemmConfigurationBase
{
    // Block (macro) tile extents.
    static constexpr ck_tile::index_t M_TILE = 256;
    static constexpr ck_tile::index_t N_TILE = 256;
    static constexpr ck_tile::index_t K_TILE = 16;

    // Number of warps along each dimension of the block tile.
    static constexpr ck_tile::index_t M_WARP = 2;
    static constexpr ck_tile::index_t N_WARP = 2;
    static constexpr ck_tile::index_t K_WARP = 1;

    // Per-warp tile extents; 2-byte element types use K = 8, every other size uses K = 16.
    static constexpr ck_tile::index_t M_WARP_TILE = 32;
    static constexpr ck_tile::index_t N_WARP_TILE = 32;
    static constexpr ck_tile::index_t K_WARP_TILE = (sizeof(PrecisionType) == 2) ? 8 : 16;

    static constexpr bool PERSISTENT = IsPersistent;

    // Shadows GemmConfigurationBase::SCHEDULER with an identical value.
    static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
};
/// Bundles the element types of one GEMM problem.
///
/// @tparam AType Element type of tensor A.
/// @tparam BType Element type of tensor B; defaults to AType.
/// @tparam CType Element type of tensor C; defaults to AType.
///
/// The accumulator type is always float, independent of the template arguments.
template <typename AType, typename BType = AType, typename CType = AType>
struct StreamKGemmTypeConfiguration
{
    using ADataType   = AType;
    using BDataType   = BType;
    using CDataType   = CType;
    using AccDataType = float;
};
/// @brief Registers every command-line option of the Stream-K GEMM example and parses argv.
///
/// Declared `inline` because this function is defined in a header: without it, including
/// the header from more than one translation unit produces multiple definitions.
///
/// @param argc Argument count as received by main.
/// @param argv Argument vector as received by main.
/// @return A tuple of (parse result, populated parser). The parser is returned even when
///         parsing fails; callers are expected to check the boolean first.
inline auto createArgs(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "512", "m dimension")
        .insert("n", "512", "n dimension")
        .insert("k", "512", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "C", "B tensor data layout - Column by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("reduction_strategy",
                "atomic",
                "strategy for storing results in C tensor - atomic/linear")
        .insert("persistent_dp",
                "0",
                "0. Non-persistent data-parallel section, 1 Fully persistent kernel.")
        // A stride of 0 asks the runner to substitute the default stride for the layout.
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
        .insert("warmup", "50", "number of iterations before benchmarking the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("init", "0", "0:random, 1:linear, 2:constant(1)")
        .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true");

    const bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

View File

@@ -0,0 +1,334 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/common/utils.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout)
{
return ck_tile::bool_constant<
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
/// @brief Derives the relative and absolute tolerances used to validate a GEMM result.
///
/// Two error sources are considered and the larger threshold of each kind is returned:
/// the accumulation inside one workgroup (over ceil(k_dim / k_batch) elements) and the
/// combination of k_batch partial results in the output type.
///
/// @param k_dim                 Length of the K dimension.
/// @param k_batch               Number of partial results combined per C tile; used as a
///                              divisor, so it must be non-zero.
/// @param max_accumulated_value Largest value of the reference result.
/// @return ck_tile tuple of (relative tolerance, absolute tolerance).
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculateRtolAtol(const ck_tile::index_t k_dim,
                       const ck_tile::index_t k_batch,
                       const float max_accumulated_value)
{
    // The narrower of the two input types drives the compute precision.
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;

    // Number of K elements each partial accumulation covers.
    const auto k_per_batch = ck_tile::integer_divide_ceil(k_dim, k_batch);

    // Thresholds for the accumulation performed within a single workgroup.
    const auto rtol_accumulate =
        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(k_per_batch);
    const auto atol_accumulate =
        ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
            max_accumulated_value / k_batch, k_per_batch);

    // Thresholds for several workgroups contributing to the same C macro tile.
    const auto rtol_combine =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(k_batch);
    const auto atol_combine = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, k_batch);

    // Keep whichever bound is looser.
    return ck_tile::make_tuple(std::max(rtol_accumulate, rtol_combine),
                               std::max(atol_accumulate, atol_combine));
}
// Forward declaration of the Stream-K GEMM launcher; invokeGemm below calls it, so the
// definition must be provided by whichever file includes this one.
// Returns a tuple whose float is named average_time and whose index is named
// num_wgs_per_tile by the caller in this file.
// Note: ReductionStrategy has no default and does not appear in the function parameters,
// so it cannot be deduced; every call must spell out all template arguments, which means
// the default given for CDEElementWise is never actually used.
template <typename GemmConfiguration,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccumulatorDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& stream_config);
/// @brief Packs the device buffers and problem sizes into host arguments and dispatches
///        the run-time reduction strategy to the matching compile-time gemm instantiation.
///
/// @param a_m_k_device_memory Device buffer holding tensor A.
/// @param b_k_n_device_memory Device buffer holding tensor B.
/// @param c_m_n_device_memory Device buffer receiving tensor C.
/// @param m_dim,n_dim,k_dim   GEMM problem sizes.
/// @param stride_a,stride_b,stride_c Strides of A, B and C.
/// @param warmup_iterations   Forwarded to the stream configuration.
/// @param repeat_iterations   Forwarded to the stream configuration.
/// @param flush_cache         Forwarded to the stream configuration.
/// @param reduction_strategy  Atomic selects the atomic kernel; any other value selects
///                            the linear kernel.
/// @return Whatever gemm returns: (average time, workgroups per tile).
template <typename GemmConfiguration,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccumulatorDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
std::tuple<float, ck_tile::index_t> invokeGemm(ck_tile::DeviceMem& a_m_k_device_memory,
                                               ck_tile::DeviceMem& b_k_n_device_memory,
                                               ck_tile::DeviceMem& c_m_n_device_memory,
                                               ck_tile::index_t m_dim,
                                               ck_tile::index_t n_dim,
                                               ck_tile::index_t k_dim,
                                               ck_tile::index_t stride_a,
                                               ck_tile::index_t stride_b,
                                               ck_tile::index_t stride_c,
                                               int warmup_iterations,
                                               int repeat_iterations,
                                               bool flush_cache,
                                               ck_tile::StreamKReductionStrategy reduction_strategy)
{
    const ck_tile::StreamKHostArgs host_args{a_m_k_device_memory.GetDeviceBuffer(),
                                             b_k_n_device_memory.GetDeviceBuffer(),
                                             c_m_n_device_memory.GetDeviceBuffer(),
                                             m_dim,
                                             n_dim,
                                             k_dim,
                                             stride_a,
                                             stride_b,
                                             stride_c};

    // Both strategies launch with the same settings. The fields are positional;
    // NOTE(review): the meaning of the literal nullptr/true/1/true entries should be
    // confirmed against the ck_tile::stream_config definition.
    const ck_tile::stream_config launch_config{
        nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache};

    if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
    {
        return gemm<GemmConfiguration,
                    ADataType,
                    BDataType,
                    DsDataType,
                    AccumulatorDataType,
                    CDataType,
                    ALayout,
                    BLayout,
                    DsLayout,
                    CLayout,
                    CDEElementWise,
                    ck_tile::StreamKReductionStrategy::Atomic>(host_args, launch_config);
    }

    // Every non-atomic strategy falls through to the linear reduction kernel.
    return gemm<GemmConfiguration,
                ADataType,
                BDataType,
                DsDataType,
                AccumulatorDataType,
                CDataType,
                ALayout,
                BLayout,
                DsLayout,
                CLayout,
                CDEElementWise,
                ck_tile::StreamKReductionStrategy::Linear>(host_args, launch_config);
}
/// @brief Compares the device result with a reference and reports the outcome on stdout.
///
/// @param c_m_n_device_result          Result produced by the kernel under test.
/// @param c_m_n_reference              Reference result to compare against.
/// @param relative_absolute_tolerances Tuple of (relative tolerance, absolute tolerance).
/// @param variant                      Label of the reference source, used in the report line.
/// @return true when ck_tile::check_err accepts the result, false otherwise.
template <typename CDataType>
bool doVerify(const ck_tile::HostTensor<CDataType>& c_m_n_device_result,
              const ck_tile::HostTensor<CDataType>& c_m_n_reference,
              const ck_tile::tuple<double, double>& relative_absolute_tolerances,
              const char* variant)
{
    const double rtol = relative_absolute_tolerances.at(ck_tile::number<0>{});
    const double atol = relative_absolute_tolerances.at(ck_tile::number<1>{});

    const bool matches = ck_tile::check_err(
        c_m_n_device_result, c_m_n_reference, "Error: Incorrect results!", rtol, atol);

    std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
              << std::endl;

    const char* verdict = matches ? "correct" : "fail";
    std::cout << "The " << variant << " verification result is:" << verdict << std::endl;

    return matches;
}
/// @brief Translates the command-line spelling of a reduction strategy into its enum value.
/// @param strategy Either "atomic" or "linear" (exact, case-sensitive match).
/// @return The corresponding ck_tile::StreamKReductionStrategy.
/// @throws std::runtime_error for any other string.
ck_tile::StreamKReductionStrategy getReductionStrategyValue(const std::string& strategy)
{
    if(strategy == "atomic")
        return ck_tile::StreamKReductionStrategy::Atomic;

    if(strategy == "linear")
        return ck_tile::StreamKReductionStrategy::Linear;

    throw std::runtime_error("Unsupported Stream-K reduction strategy !!!");
}
/// @brief Runs one Stream-K GEMM end to end: parse options, initialise the inputs,
///        launch and time the kernel, print throughput, and optionally validate.
///
/// @tparam GemmConfiguration Tile/pipeline configuration; preshuffle, structured sparsity
///                           and A/B permutation are rejected at compile time.
/// @tparam TypeConfiguration Supplies ADataType, BDataType, AccDataType and CDataType.
/// @param argc,argv Command line, parsed with createArgs.
/// @param a_layout  Layout tag of tensor A.
/// @param b_layout  Layout tag of tensor B.
/// @param c_layout  Layout tag of tensor C (only its type is used).
/// @return -1 when the command line cannot be parsed; otherwise 1 when validation passed
///         or was skipped with `-v 0`, and 0 when validation failed or `-v` holds an
///         unrecognised value.
template <typename GemmConfiguration,
          typename TypeConfiguration,
          typename ALayout,
          typename BLayout,
          typename CLayout>
int runGemmExampleWithLayouts(int argc,
                              char* argv[],
                              const ALayout a_layout                   = ALayout{},
                              const BLayout b_layout                   = BLayout{},
                              [[maybe_unused]] const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = createArgs(argc, argv);
    if(!result)
        return -1;

    static_assert(!GemmConfiguration::PRESHUFFLE, "Not implemented");
    static_assert(!GemmConfiguration::USE_STRUCTURED_SPARSITY, "Not implemented");
    static_assert(!GemmConfiguration::PERMUTE_A, "Not implemented");
    static_assert(!GemmConfiguration::PERMUTE_B, "Not implemented");

    using ADataType           = typename TypeConfiguration::ADataType;
    using BDataType           = typename TypeConfiguration::BDataType;
    using AccumulatorDataType = typename TypeConfiguration::AccDataType;
    using CDataType           = typename TypeConfiguration::CDataType;

    ck_tile::index_t m_dim = arg_parser.get_int("m");
    ck_tile::index_t n_dim = arg_parser.get_int("n");
    ck_tile::index_t k_dim = arg_parser.get_int("k");

    ck_tile::index_t stride_a = arg_parser.get_int("stride_a");
    ck_tile::index_t stride_b = arg_parser.get_int("stride_b");
    ck_tile::index_t stride_c = arg_parser.get_int("stride_c");

    int warmup_iterations        = arg_parser.get_int("warmup");
    int repeat_iterations        = arg_parser.get_int("repeat");
    ck_tile::index_t init_method = arg_parser.get_int("init");
    bool flush_cache             = arg_parser.get_bool("flush_cache");
    // Read the validation mode once so both branches below test the same value.
    const auto validation_mode = arg_parser.get_int("v");

    ck_tile::StreamKReductionStrategy reduction_strategy =
        getReductionStrategyValue(arg_parser.get_str("reduction_strategy"));

    // Let ck_tile pick the stride for the layout when the user left it unspecified.
    stride_a = ck_tile::get_default_stride(m_dim, k_dim, stride_a, is_row_major(a_layout));
    stride_b = ck_tile::get_default_stride(k_dim, n_dim, stride_b, is_row_major(b_layout));
    stride_c = ck_tile::get_default_stride(m_dim, n_dim, stride_c, is_row_major(CLayout{}));

    ck_tile::HostTensor<ADataType> a_m_k_host(
        ck_tile::host_tensor_descriptor(m_dim, k_dim, stride_a, is_row_major(a_layout)));
    ck_tile::HostTensor<BDataType> b_k_n_host(
        ck_tile::host_tensor_descriptor(k_dim, n_dim, stride_b, is_row_major(b_layout)));
    ck_tile::HostTensor<CDataType> c_m_n_device_result(
        ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));

    // Input initialisation: 0 = uniform in [-5, 5], 1 = monotonic sequence,
    // 2 = constant 1, anything else = all zeros.
    if(init_method == 0)
    {
        ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_host);
        ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_host);
    }
    else if(init_method == 1)
    {
        ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k_host);
        ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n_host);
    }
    else if(init_method == 2)
    {
        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_host);
        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_host);
    }
    else
    {
        a_m_k_host.SetZero();
        b_k_n_host.SetZero();
    }

    ck_tile::DeviceMem a_m_k_device_memory(a_m_k_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_k_n_device_memory(b_k_n_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_m_n_device_memory(c_m_n_device_result.get_element_space_size_in_bytes());

    a_m_k_device_memory.ToDevice(a_m_k_host.data());
    b_k_n_device_memory.ToDevice(b_k_n_host.data());
    c_m_n_device_memory.SetZero();
    c_m_n_device_result.SetZero();

    auto [average_time, num_wgs_per_tile] = invokeGemm<GemmConfiguration,
                                                       ADataType,
                                                       BDataType,
                                                       ck_tile::tuple<>,
                                                       AccumulatorDataType,
                                                       CDataType,
                                                       ALayout,
                                                       BLayout,
                                                       ck_tile::tuple<>,
                                                       CLayout>(a_m_k_device_memory,
                                                                b_k_n_device_memory,
                                                                c_m_n_device_memory,
                                                                m_dim,
                                                                n_dim,
                                                                k_dim,
                                                                stride_a,
                                                                stride_b,
                                                                stride_c,
                                                                warmup_iterations,
                                                                repeat_iterations,
                                                                flush_cache,
                                                                reduction_strategy);

    c_m_n_device_memory.FromDevice(c_m_n_device_result.data());

    // The leading std::size_t operand keeps both products in 64-bit arithmetic.
    std::size_t flop     = std::size_t(2) * m_dim * n_dim * k_dim;
    std::size_t num_byte = sizeof(ADataType) * m_dim * k_dim + sizeof(BDataType) * n_dim * k_dim +
                           sizeof(CDataType) * m_dim * n_dim;
    // average_time is reported in milliseconds (see the " ms" label below), hence the
    // 1e9 and 1e6 divisors for TFlops and GB/s respectively.
    float tflops     = static_cast<float>(flop) / 1.E9 / average_time;
    float gb_per_sec = num_byte / 1.E6 / average_time;

    std::cout << "Run Gemm kernel with M=" << m_dim << " N=" << n_dim << " K=" << k_dim
              << " StrideA=" << stride_a << " StrideB=" << stride_b << " StrideC=" << stride_c
              << " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name
              << " C_Layout=" << CLayout::name
              << " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
              << " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
              << " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
              << " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
              << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << average_time
              << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;

    // With validation explicitly disabled (-v 0) there is nothing that can fail, so the
    // run counts as a pass; an unrecognised mode is still reported as a failure.
    bool pass = (validation_mode == 0);

    // Memory on host to store the reference result
    ck_tile::HostTensor<CDataType> c_m_n_reference(
        ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
    c_m_n_reference.SetZero();

    if(validation_mode == 1) // Validate on the CPU
    {
        ck_tile::reference_gemm<ADataType, BDataType, AccumulatorDataType, CDataType>(
            a_m_k_host, b_k_n_host, c_m_n_reference);

        const float max_accumulated_value =
            *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
        // The number of workgroups per tile plays the role of the split-K batch count.
        const auto relative_absolute_tolerances =
            calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
                k_dim, num_wgs_per_tile, max_accumulated_value);

        pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "CPU");
    }
    else if(validation_mode == 2) // Validate on the GPU
    {
        // Memory on device to store the GPU reference result
        ck_tile::DeviceMem c_m_n_gpu_buffer_reference(
            c_m_n_reference.get_element_space_size_in_bytes());
        c_m_n_gpu_buffer_reference.SetZero();

        ADataType* d_A = static_cast<ADataType*>(a_m_k_device_memory.GetDeviceBuffer());
        BDataType* d_B = static_cast<BDataType*>(b_k_n_device_memory.GetDeviceBuffer());
        CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buffer_reference.GetDeviceBuffer());

        ck_tile::reference_gemm_gpu<ADataType,
                                    BDataType,
                                    AccumulatorDataType,
                                    CDataType,
                                    ALayout,
                                    BLayout,
                                    CLayout>(
            d_A, d_B, d_C, m_dim, n_dim, k_dim, stride_a, stride_b, stride_c);

        c_m_n_gpu_buffer_reference.FromDevice(c_m_n_reference.data());

        const float max_accumulated_value =
            *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
        const auto relative_absolute_tolerances =
            calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
                k_dim, num_wgs_per_tile, max_accumulated_value);

        pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "GPU");
    }

    return pass;
}

View File

@@ -0,0 +1,236 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <cstddef>
#include <exception>
#include <functional>
#include <iostream>
#include <stdexcept>
#include <string>
#include <tuple>

#include "gemm_utils.hpp"
#include "ck_tile/ops/common.hpp"
/// @brief Instantiates the Stream-K GEMM kernel for the given configuration, launches it
///        through the timing helper and reports the measured time.
///
/// @tparam GemmConfiguration Supplies tile sizes, padding flags, scheduler, etc.
/// @tparam ReductionStrategy Decides what is reset before every timed repetition: the
///                           C tensor (Atomic) or the workspace (Linear).
/// @param args          Host-side GEMM arguments (pointers, sizes, strides).
/// @param stream_config Stream, logging and timing settings.
/// @return (average kernel time, estimated number of workgroups per C tile).
/// @throws std::runtime_error when the kernel rejects the arguments or when clearing the
///         C tensor fails.
template <typename GemmConfiguration,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccumulatorDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename CDEElementWise,
          ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
                                         const ck_tile::stream_config& stream_config)
{
    using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<GemmConfiguration::M_TILE,
                                                               GemmConfiguration::N_TILE,
                                                               GemmConfiguration::K_TILE>,
                                             ck_tile::sequence<GemmConfiguration::M_WARP,
                                                               GemmConfiguration::N_WARP,
                                                               GemmConfiguration::K_WARP>,
                                             ck_tile::sequence<GemmConfiguration::M_WARP_TILE,
                                                               GemmConfiguration::N_WARP_TILE,
                                                               GemmConfiguration::K_WARP_TILE>,
                                             GemmConfiguration::PERMUTE_A,
                                             GemmConfiguration::PERMUTE_B>;

    using TilePartitioner = ck_tile::
        StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfiguration::PERSISTENT>;

    using GemmUniversalTraits =
        ck_tile::TileGemmUniversalTraits<GemmConfiguration::PAD_M,
                                         GemmConfiguration::PAD_N,
                                         GemmConfiguration::PAD_K,
                                         GemmConfiguration::DOUBLE_SMEM_BUFFER,
                                         ALayout,
                                         BLayout,
                                         ELayout,
                                         GemmConfiguration::TRANSPOSE_C,
                                         GemmConfiguration::USE_STRUCTURED_SPARSITY,
                                         GemmConfiguration::PERSISTENT,
                                         GemmConfiguration::NUM_WAVE_GROUPS,
                                         GemmConfiguration::PRESHUFFLE>;

    // We create the GEMM pipeline without specifying has_hot_loop or tail_num.
    // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
    // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
    // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
    using UniversalGemmProblem =
        ck_tile::UniversalGemmPipelineProblem<ADataType,
                                              BDataType,
                                              AccumulatorDataType,
                                              GemmShape,
                                              GemmUniversalTraits,
                                              GemmConfiguration::SCHEDULER>;

    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;

    using GemmEpilogue = ck_tile::CShuffleEpilogue<
        ck_tile::CShuffleEpilogueProblem<ADataType,
                                         BDataType,
                                         DsDataType,
                                         AccumulatorDataType,
                                         CDataType,
                                         DsLayout,
                                         ELayout,
                                         CDEElementWise,
                                         TilePartitioner::MPerBlock,
                                         TilePartitioner::NPerBlock,
                                         GemmConfiguration::M_WARP,
                                         GemmConfiguration::N_WARP,
                                         GemmConfiguration::M_WARP_TILE,
                                         GemmConfiguration::N_WARP_TILE,
                                         GemmConfiguration::K_WARP_TILE,
                                         UniversalGemmProblem::TransposeC,
                                         GemmConfiguration::NUM_WAVE_GROUPS>>;

    using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

    auto kernel_args = Kernel::MakeKernelArgs(args);

    // The workspace starts zeroed; for the Linear strategy it is zeroed again before
    // every repetition by reset_data_buffers below.
    const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
    ck_tile::DeviceMem workspace_data(workspace_size);
    workspace_data.SetZero();
    kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();

    dim3 grids  = Kernel::GridSize(kernel_args.tile_partitioner);
    dim3 blocks = Kernel::BlockSize();

    // Never launch a kernel that reports the arguments as unsupported.
    if(!Kernel::IsSupportedArgument(kernel_args))
    {
        throw std::runtime_error(
            "Stream-K GEMM: the kernel does not support the given arguments.");
    }

    if(stream_config.log_level_ > 0)
    {
        std::cout << "Launching Stream-K GEMM kernel with grid: {" << grids.x << ", " << grids.y
                  << ", " << grids.z << "}, blocks: {" << blocks.x << ", " << blocks.y << ", "
                  << blocks.z << "}, workspace bytes: " << workspace_size << std::endl;
    }

    // Runs before each kernel repetition so that results of one run do not leak into the next.
    auto reset_data_buffers = [&]() {
        if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
        {
            // Clear the output C tensor before each repetition of the kernel.
            // Widen before multiplying so large M * N cannot overflow the index type.
            // NOTE(review): this clears M * N contiguous elements; confirm that is enough
            // when a non-default C stride is used.
            const std::size_t c_size_in_bytes = static_cast<std::size_t>(args.M) *
                                                static_cast<std::size_t>(args.N) *
                                                sizeof(CDataType);
            const hipError_t status =
                hipMemsetAsync(args.e_ptr, 0, c_size_in_bytes, stream_config.stream_id_);
            if(status != hipSuccess)
            {
                throw std::runtime_error(std::string("Failed to clear the C tensor: ") +
                                         hipGetErrorString(status));
            }
        }
        else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Linear)
        {
            // Reset the workspace to zero before each repetition of the kernel.
            workspace_data.SetZero();
        }
    };

    std::function<void()> preprocess = reset_data_buffers;

    float average_time =
        ck_tile::launch_kernel_time_mask(stream_config,
                                         preprocess,
                                         ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
                                             Kernel{}, grids, blocks, 0, kernel_args));

    ck_tile::index_t num_wgs_per_tile = kernel_args.tile_partitioner.estimate_num_wgs_per_tile();

    return std::tuple{average_time, num_wgs_per_tile};
}
#include "run_gemm_example.inc"
/// @brief Maps the run-time layout strings onto compile-time layout tags and runs the example.
///
/// Only A row-major ("R") combined with B column-major ("C") is available; C is always
/// row-major.
///
/// @param a_layout  Layout of tensor A as given on the command line.
/// @param b_layout  Layout of tensor B as given on the command line.
/// @param argc,argv Command line, forwarded unchanged.
/// @return The result of runGemmExampleWithLayouts.
/// @throws std::runtime_error for every other layout combination.
template <typename GemmConfiguration, typename TypeConfiguration>
int runGemmExamplePrecisionType(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    const bool is_row_col = (a_layout == "R") && (b_layout == "C");
    if(!is_row_col)
    {
        throw std::runtime_error("Unsupported layouts.");
    }

    return runGemmExampleWithLayouts<GemmConfiguration, TypeConfiguration>(
        argc, argv, Row{}, Col{}, Row{});
}
template <template <typename PrecisionType, bool IsPersistent> typename GemmConfiguration>
int runGemmExample(int argc, char* argv[])
{
auto [result, arg_parser] = createArgs(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
auto persistent_data_parallel = arg_parser.get_bool("persistent_dp");
if(data_type == "bf16")
{
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::bf16_t>;
if(persistent_data_parallel)
{
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp16")
{
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::half_t>;
if(persistent_data_parallel)
{
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp8")
{
using TypeConfiguration =
StreamKGemmTypeConfiguration<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
if(persistent_data_parallel)
{
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else if(data_type == "bf8")
{
using TypeConfiguration =
StreamKGemmTypeConfiguration<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
if(persistent_data_parallel)
{
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, true>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
else
{
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, false>,
TypeConfiguration>(a_layout, b_layout, argc, argv);
}
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
return false;
}
/// Program entry point.
///
/// runGemmExample returns 1 on success, 0 on a validation failure and -1 when the
/// command line cannot be parsed. Only a positive value maps to exit code 0; a plain
/// logical negation would turn -1 into 0 and report a parse failure as success.
/// NOTE(review): if the argument parser also reports a help request as a parse failure,
/// asking for help now exits with 1 — confirm that is acceptable.
/// Exceptions raised for unsupported options or failed launches are reported on stderr
/// and produce exit code 1 instead of terminating the process abnormally.
int main(int argc, char* argv[])
{
    try
    {
        const int status = runGemmExample<GemmConfigurationMemoryInterwave>(argc, argv);
        return status > 0 ? 0 : 1;
    }
    catch(const std::exception& error)
    {
        std::cerr << "Error: " << error.what() << std::endl;
        return 1;
    }
}