Merge commit 'bebf0e9d158c13d34c9f263a9551f60fa463bc66' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-29 22:11:28 +00:00
parent 1ff47b0020
commit 78f2779870
11 changed files with 856 additions and 19 deletions

View File

@@ -10,16 +10,15 @@ The grouped GEMM examples include two advanced optimization features:
Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches.
- **Implementation**: Available in `grouped_gemm_preshuffle.cpp`
- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration
- **Configuration**: Uses `GemmConfigPreshuffleDecode` and `GemmConfigPreshufflePrefill` template configuration
- **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts
- **Benefits**: Improved memory efficiency and reduced data movement
#### Persistence Mode
Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy.
- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm`
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
#### Multi-D Operations
Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output.
@@ -31,7 +30,8 @@ Multi-D operations extend the standard GEMM operation by supporting additional e
- **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call
- **Build Target**: `make tile_example_grouped_gemm_multi_d -j`
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
Multi-D operations supports both persistence and non-persistence modes.
Weight preshuffle supports only on non-persistence mode.
## Build
```
@@ -48,7 +48,7 @@ make tile_example_grouped_gemm_multi_d -j
# The quant grouped gemm fp8 example
make tile_example_quant_grouped_gemm -j
```
This will result in an executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
Each example will result in an corresponding executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
## example

View File

@@ -166,6 +166,112 @@ float grouped_gemm_multi_d(const std::vector<grouped_gemm_multi_d_kargs>& gemm_d
return ave_time;
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename CDEElementWise>
float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GemmUniversalTraits =
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout>;
float ave_time{0};
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
// These are automatically run inside the kernel based on the given input data.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
};
if(!splitk)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
return ave_time;
}
#include "run_grouped_gemm_multi_d_example.inc"
int main(int argc, char* argv[])

View File

@@ -95,6 +95,7 @@ struct GemmConfigV3 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
@@ -170,7 +171,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<2>;
using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<DsDataType::size()>;
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
{
@@ -201,7 +202,7 @@ std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_multi_d_kargs>& gemm_descs)
{
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<2>);
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>);
}
template <typename GemmConfig,

View File

@@ -86,9 +86,43 @@ float invoke_gemm(int n_warmup,
}
else
{
(void)group_count;
// not supported yet
throw std::runtime_error("Persistent grouped gemm multiple-d is not supported yet");
std::vector<ck_tile::GemmTransKernelArg<DsDataType::size()>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = args[0].k_batch > 1;
for(const auto& arg : args)
{
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, 2>{{arg.a_ptr},
{arg.b_ptr},
arg.ds_ptr,
arg.e_ptr,
arg.M,
arg.N,
arg.K,
{arg.stride_A},
{arg.stride_B},
arg.stride_Ds,
arg.stride_E,
arg.k_batch});
}
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
HIP_CHECK_ERROR(hipMemcpyWithStream(
kargs_ptr,
kargs.data(),
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<DsDataType::size()>),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time =
grouped_gemm_multi_d_tileloop<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
CDEElementWise>(stream, group_count, kargs_ptr, splitk);
}
return ave_time;
}
@@ -322,12 +356,6 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc,
b_k_n_tensors[i],
{d0_m_n_tensors[i], d1_m_n_tensors[i]},
e_m_n_host_refs[i]);
std::cout << "e_m_n_host_refs[i]: " << std::endl;
e_m_n_host_refs[i].print_first_n(std::cout, 10);
std::cout << std::endl;
std::cout << "e_m_n_tensors[i]: " << std::endl;
e_m_n_tensors[i].print_first_n(std::cout, 10);
std::cout << std::endl;
const float max_accumulated_value =
*std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end());

View File

@@ -0,0 +1,5 @@
if(GPU_TARGETS MATCHES "gfx9")
add_executable(tile_example_streamk_gemm_basic EXCLUDE_FROM_ALL streamk_gemm_basic.cpp)
else()
message(DEBUG "Skipping ck_tile streamk gemm tests for current target")
endif()

View File

@@ -0,0 +1,37 @@
# Stream-K GEMM
This folder contains examples of Stream-K GEMMs using the ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
# Compile the Stream-K kernels
make tile_example_streamk_gemm_basic -j
```
This will result in an executable `build/bin/tile_example_streamk_gemm_basic`
## example
```
args:
-m m dimension (default:512)
-n n dimension (default:512)
-k k dimension (default:512)
-a_layout tensor A data layout (default: R)
-b_layout tensor B data layout (default: C)
-c_layout tensor C data layout (default: R)
-num_sk_blocks number of Stream-K blocks. -1: chosen by algorithm, or user selected (default:-1)
-reduction_strategy strategy for storing results in C tensor. atomic/reduction (default:atomic)
-stride_a tensor A stride (default:0)
-stride_b tensor B stride (default:0)
-stride_c tensor C stride (default:0)
-v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
-prec data type. fp16/bf16 (default:fp16)
-warmup number of iterations before benchmarking the kernel (default:50)
-repeat number of iterations to benchmark the kernel (default:100)
-timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu)
-init data initialization strategy. 0:random, 1:linear, 2:constant(1) (default:0)
-flush_cache flush the cache before running the kernel (default:true)
```

View File

@@ -0,0 +1,106 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
struct GemmConfigBase
{
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Persistent = false;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
struct StreamKGemmTypeConfig
{
using ADataType = ADataType_;
using BDataType = BDataType_;
using AccDataType = float;
using CDataType = CDataType_;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "512", "m dimension")
.insert("n", "512", "n dimension")
.insert("k", "512", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("num_sk_blocks",
"-1",
"number of Stream-K blocks. -1: chosen by algorithm, or user selected")
.insert("reduction_strategy",
"atomic",
"strategy for storing results in C tensor - atomic/reduction")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16")
.insert("warmup", "50", "number of iterations before benchmarking the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}

View File

@@ -0,0 +1,351 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// Estimate the number of WGs contributing to the same macro tile in C
template <ck_tile::StreamKReductionStrategy ReductionStrategy, typename TilePartitioner>
int estimate_num_wgs_per_tile(const TilePartitioner& tile_partitioner)
{
// In the case of non-atomic reduction or DP only, there will always be 1 WG contributing to a
// macro time in C
int num_wgs_per_tile = 1;
// Otherwise, for atomics, multiple WGs may be contributing to the same macro tile in C
if(tile_partitioner.sk_num_blocks > 0 &&
ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
// Determine the number of iterations per WG for a given macro tile in C
uint32_t k_iters_per_block = tile_partitioner.k_iters_per_big_block - 1;
// Estimate the number of WGs per macro tile
num_wgs_per_tile = (tile_partitioner.k_iters_per_tile.get() / (k_iters_per_block)) +
((tile_partitioner.k_iters_per_tile.get() % k_iters_per_block) != 0);
}
return std::max(num_wgs_per_tile, 1);
}
template <typename Layout>
static constexpr inline auto is_row_major(Layout)
{
return ck_tile::bool_constant<
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to multiple WGs working in the same C macro tile
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s);
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t stride_A,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
int n_warmup,
int n_repeat,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy,
uint32_t num_sk_blocks)
{
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C,
reduction_strategy,
num_sk_blocks};
std::tuple<float, int> ave_time_and_batch;
if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
{
ave_time_and_batch = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
}
else /*Reduction*/
{
ave_time_and_batch = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Reduction>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
}
return ave_time_and_batch;
}
template <typename CDataType>
bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const ck_tile::HostTensor<CDataType>& c_m_n_ref,
const ck_tile::tuple<double, double>& rtol_atol,
const char* variant)
{
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The " << variant << " verification result is:" << (pass ? "correct" : "fail")
<< std::endl;
return pass;
}
ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string& strategy)
{
if(strategy == "atomic")
{
return ck_tile::StreamKReductionStrategy::Atomic;
}
else if(strategy == "reduction")
{
return ck_tile::StreamKReductionStrategy::Reduction;
}
else
{
throw std::runtime_error("Unsupported Stream-K reduction strategy !!!");
}
}
template <typename GemmConfig,
typename TypeConfig,
typename ALayout,
typename BLayout,
typename CLayout>
int run_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
static_assert(!GemmConfig::Preshuffle, "Not implemented");
static_assert(!GemmConfig::UseStructuredSparsity, "Not implemented");
static_assert(!GemmConfig::PermuteA, "Not implemented");
static_assert(!GemmConfig::PermuteB, "Not implemented");
using ADataType = typename TypeConfig::ADataType;
using BDataType = typename TypeConfig::BDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
bool flush_cache = arg_parser.get_bool("flush_cache");
ck_tile::StreamKReductionStrategy reduction_strategy =
get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
uint32_t num_sk_blocks = static_cast<uint32_t>(arg_parser.get_int("num_sk_blocks"));
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
}
else
{
a_m_k.SetZero();
b_k_n.SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
auto [ave_time, num_wgs_per_tile] = invoke_gemm<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
n_warmup,
n_repeat,
flush_cache,
reduction_strategy,
num_sk_blocks);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
<< " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
bool pass = true;
// Memory on host to store gpu reference result
ck_tile::HostTensor<CDataType> c_m_n_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_ref.SetZero();
if(arg_parser.get_int("v") == 1) // Validate on the CPU
{
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, num_wgs_per_tile, max_accumulated_value);
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
}
else if(arg_parser.get_int("v") == 2) // Validate on the GPU
{
// Memory on device to store gpu reference result
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, num_wgs_per_tile, max_accumulated_value);
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
}
return pass;
}

View File

@@ -0,0 +1,193 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename CDEElementWise,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
using TilePartitioner = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
GemmConfig::Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
const auto Run = [&](const auto memory_operation) -> std::tuple<float, int> {
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
GemmConfig::Scheduler>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation.value,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
// Function to clear the output C tensor results after each repetition of the kernel
auto clear_gemm_output = [&]() {
if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
std::function<void()> preprocess = clear_gemm_output;
float ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
int num_wgs_per_tile = estimate_num_wgs_per_tile<ReductionStrategy>(kargs.tile_partitioner);
return std::tuple{ave_time, num_wgs_per_tile};
};
if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// Since we are doing stream K, in the case of
// atomics, multiple workgroups may write to the same
// output tile in the C tensor, so we must atomic add
// the results (not set)
ck_tile::memory_operation_enum::atomic_add>{});
}
else // We are using ck_tile::StreamKReductionStrategy::Reduction
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// In this case, there is only ever 1 WG writing final
// results to each macro tile in the C tensor, so we
// can do a set.
ck_tile::memory_operation_enum::set>{});
}
}
template <typename GemmConfig, typename TypeConfig>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, TypeConfig>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported layouts.");
}
return 0;
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "bf16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
return false;
}
int main(int argc, char* argv[])
{
return !run_gemm_example<GemmConfigMemoryInterwave>(argc, argv);
}

View File

@@ -25,3 +25,4 @@ add_subdirectory(22_gemm_multi_abd)
add_subdirectory(35_batched_transpose)
add_subdirectory(38_block_scale_gemm)
add_subdirectory(39_copy)
add_subdirectory(40_streamk_gemm)

View File

@@ -324,10 +324,18 @@ struct GroupedGemmKernel
}
else // SingleSmemBuffer
{
if constexpr(UsePersistentKernel)
{
RunGemmWithPipelineSelection(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
RunGemmWithPipelineSelection(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
kargs,
splitk_batch_offset,
i_m,
i_n);
}
else // Non-persistent kernel
{
@@ -365,6 +373,7 @@ struct GroupedGemmKernel
CK_TILE_DEVICE static void
RunGemmWithPipelineSelection(const ADataType* a_ptr,
const BDataType* b_ptr,
const std::array<const void*, NumDTensor_>& ds_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const UniversalGemmKernelArgs<1, 1, NumDTensor_>& kargs,
@@ -375,7 +384,7 @@ struct GroupedGemmKernel
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);
{a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k);
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =