Code style clean-up and documentation

The following changes were made:
- Clean-up of variable names
- Addition of README
- Removal of num_cu and occupancy args; such options are meant for
  testing purposes and should not be exposed to the user
- Removal of CK_TILE_PIPELINE_MEMORY macro and PipelineTypeTraits class
  since we only support one pipeline at the moment.
This commit is contained in:
Emily Martins
2025-09-24 15:32:25 +00:00
committed by Emily Martins
parent a3499e38b2
commit 243118c275
4 changed files with 52 additions and 76 deletions

View File

@@ -0,0 +1,37 @@
# Stream-K GEMM
This folder contains examples of Stream-K GEMMs using the ck_tile tile-programming implementation.
## Build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
# Compile the Stream-K kernels
make tile_example_streamk_gemm_basic -j
```
This will result in an executable `build/bin/tile_example_streamk_gemm_basic`.
## Example
```
args:
-m m dimension (default:512)
-n n dimension (default:512)
-k k dimension (default:512)
-a_layout tensor A data layout (default: R)
-b_layout tensor B data layout (default: C)
-c_layout tensor C data layout (default: R)
-num_sk_blocks number of Stream-K blocks. -1: chosen by algorithm, or user selected (default:-1)
-reduction_strategy strategy for storing results in C tensor. atomic/reduction (default:atomic)
-stride_a tensor A stride (default:0)
-stride_b tensor B stride (default:0)
-stride_c tensor C stride (default:0)
-v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1)
-prec data type. fp16/bf16 (default:fp16)
-warmup number of iterations before benchmarking the kernel (default:50)
-repeat number of iterations to benchmark the kernel (default:100)
-timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu)
-init data initialization strategy. 0:random, 1:linear, 2:constant(1) (default:0)
-flush_cache flush the cache before running the kernel (default:true)
```

View File

@@ -2,16 +2,11 @@
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#define CK_TILE_PIPELINE_MEMORY 1
struct GemmConfigBase
{
static constexpr bool kPadM = true;
@@ -27,7 +22,6 @@ struct GemmConfigBase
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool DoubleSmemBuffer = false;
@@ -48,20 +42,7 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
@@ -109,13 +90,6 @@ auto create_args(int argc, char* argv[])
.insert("reduction_strategy",
"atomic",
"strategy for storing results in C tensor - atomic/reduction")
.insert(
"occupancy",
"-1",
"maximum number of workgroups per CU - value of -1 queries occupancy from the device")
.insert("num_cu",
"-1",
"number of compute units (CUs) - value of -1 uses number of CUs on the device")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")

View File

@@ -26,10 +26,10 @@ int estimate_num_wgs_per_tile(const TilePartitioner& tile_partitioner)
}
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
static constexpr inline auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
return ck_tile::bool_constant<
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
@@ -65,10 +65,7 @@ template <typename GemmConfig,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s,
int num_cu,
int occupancy);
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s);
template <typename GemmConfig,
typename ADataType,
@@ -94,9 +91,7 @@ std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_repeat,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy,
uint32_t num_sk_blocks,
int num_cu,
int occupancy)
uint32_t num_sk_blocks)
{
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
@@ -126,10 +121,7 @@ std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Atomic>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache},
num_cu,
occupancy);
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
}
else /*Reduction*/
{
@@ -145,10 +137,7 @@ std::tuple<float, int> invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
CLayout,
CDEElementWise,
ck_tile::StreamKReductionStrategy::Reduction>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache},
num_cu,
occupancy);
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache});
}
return ave_time_and_batch;
@@ -189,15 +178,6 @@ ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string
}
}
void validate_num_cu_and_occupancy(int num_cu, int occupancy)
{
if((num_cu == -1) != (occupancy == -1))
{
throw std::runtime_error("Arguments num_cu and occupancy must both use either (a) "
"default values (-1) or (b) non-default values.");
}
}
template <typename GemmConfig,
typename TypeConfig,
typename ALayout,
@@ -239,10 +219,6 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::StreamKReductionStrategy reduction_strategy =
get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
uint32_t num_sk_blocks = static_cast<uint32_t>(arg_parser.get_int("num_sk_blocks"));
int num_cu = arg_parser.get_int("num_cu");
int occupancy = arg_parser.get_int("occupancy");
validate_num_cu_and_occupancy(num_cu, occupancy);
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -307,9 +283,7 @@ int run_gemm_example_with_layouts(int argc,
n_repeat,
flush_cache,
reduction_strategy,
num_sk_blocks,
num_cu,
occupancy);
num_sk_blocks);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

View File

@@ -16,10 +16,7 @@ template <typename GemmConfig,
typename ELayout,
typename CDEElementWise,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& s,
int num_cu,
int occupancy)
std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
@@ -45,10 +42,7 @@ std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
const auto Run = [&](const auto memory_operation_) -> std::tuple<float, int> {
constexpr auto memory_operation = memory_operation_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
const auto Run = [&](const auto memory_operation) -> std::tuple<float, int> {
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
@@ -58,10 +52,9 @@ std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
GemmConfig::Scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -80,14 +73,12 @@ std::tuple<float, int> gemm(const ck_tile::StreamKHostArgs& args,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
memory_operation.value,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = (num_cu == -1 && occupancy == -1)
? Kernel::MakeKernelArgs(args)
: Kernel::MakeKernelArgs(args, num_cu, occupancy);
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
dim3 blocks = Kernel::BlockSize();