mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
This commit is contained in:
13
example/ck_tile/40_streamk_gemm/CMakeLists.txt
Normal file
13
example/ck_tile/40_streamk_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# The Stream-K GEMM example is only built for gfx9-class GPU targets;
# other targets skip it with a debug-level message.
if(GPU_TARGETS MATCHES "gfx9")
    add_executable(tile_example_streamk_gemm_basic streamk_gemm_basic.cpp)
    # Extra compile definitions for the example are collected in this list.
    set(EXAMPLE_GEMM_COMPILE_OPTIONS)
    if(CK_USE_OCP_FP8)
        # Propagate the OCP FP8 build option into the example's translation unit.
        list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
    endif()
    target_compile_options(tile_example_streamk_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
else()
    message(DEBUG "Skipping ck_tile streamk gemm tests for current target")
endif()
|
||||
37
example/ck_tile/40_streamk_gemm/README.md
Normal file
37
example/ck_tile/40_streamk_gemm/README.md
Normal file
@@ -0,0 +1,37 @@
|
||||
# Stream-K GEMM
|
||||
|
||||
This folder contains examples of Stream-K GEMMs using the ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx942) or leave it blank
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# Compile the Stream-K kernels
|
||||
make tile_example_streamk_gemm_basic -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_streamk_gemm_basic`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:512)
|
||||
-n n dimension (default:512)
|
||||
-k k dimension (default:512)
|
||||
-a_layout tensor A data layout (default: R)
|
||||
-b_layout tensor B data layout (default: C)
|
||||
-c_layout tensor C data layout (default: R)
|
||||
-reduction_strategy strategy for storing results in C tensor. atomic/linear (default:atomic)
|
||||
-persistent_dp persistent strategy for data-parallel section. Set to 0 for non-persistent or to 1 for persistent. (default:0)
|
||||
-stride_a tensor A stride (default:0)
|
||||
-stride_b tensor B stride (default:0)
|
||||
-stride_c tensor C stride (default:0)
|
||||
-v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-warmup number of iterations before benchmarking the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-init data initialization strategy. 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-flush_cache flush the cache before running the kernel (default:true)
|
||||
```
|
||||
85
example/ck_tile/40_streamk_gemm/gemm_utils.hpp
Normal file
85
example/ck_tile/40_streamk_gemm/gemm_utils.hpp
Normal file
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
// Compile-time switches shared by all GEMM configurations in this example.
// Concrete tile configurations (e.g. GemmConfigurationMemoryInterwave)
// inherit from this and add tile/warp sizes.
struct GemmConfigurationBase
{
    // Padding flags for the M/N/K dimensions.
    static constexpr bool PAD_M = true;
    static constexpr bool PAD_N = true;
    static constexpr bool PAD_K = true;

    // Input pre-permutation is not used by this example.
    static constexpr bool PERMUTE_A = false;
    static constexpr bool PERMUTE_B = false;

    static constexpr bool TRANSPOSE_C          = false;
    static constexpr bool USE_STRUCTURED_SPARSITY = false;

    // Passed as the template argument of ck_tile::make_kernel at launch time.
    static constexpr int BLOCK_PER_CU = 1;
    static constexpr auto SCHEDULER   = ck_tile::GemmPipelineScheduler::Intrawave;
    static constexpr ck_tile::index_t NUM_WAVE_GROUPS = 1;
    static constexpr bool PRESHUFFLE         = false;
    static constexpr bool DOUBLE_SMEM_BUFFER = false;
};
|
||||
|
||||
// Tile configuration used by the Stream-K example kernels.
// M_TILE/N_TILE/K_TILE are the per-block tile sizes, *_WARP the number of
// warps per block in each dimension, and *_WARP_TILE the per-warp tile sizes.
template <typename PrecisionType, bool IsPersistent>
struct GemmConfigurationMemoryInterwave : public GemmConfigurationBase
{
    static constexpr ck_tile::index_t M_TILE = 256;
    static constexpr ck_tile::index_t N_TILE = 256;
    static constexpr ck_tile::index_t K_TILE = 16;

    static constexpr ck_tile::index_t M_WARP = 2;
    static constexpr ck_tile::index_t N_WARP = 2;
    static constexpr ck_tile::index_t K_WARP = 1;

    static constexpr ck_tile::index_t M_WARP_TILE = 32;
    static constexpr ck_tile::index_t N_WARP_TILE = 32;
    // 16-bit element types (fp16/bf16) use a K warp-tile of 8; other sizes
    // (the 8-bit fp8/bf8 types in this example) use 16.
    static constexpr ck_tile::index_t K_WARP_TILE = sizeof(PrecisionType) == 2 ? 8 : 16;

    static constexpr bool PERSISTENT = IsPersistent;
    // NOTE(review): shadows GemmConfigurationBase::SCHEDULER with the same
    // value — redundant but harmless; confirm before removing.
    static constexpr auto SCHEDULER = ck_tile::GemmPipelineScheduler::Intrawave;
};
|
||||
|
||||
// Groups the element types used by one Stream-K GEMM instance.
// B and C default to the A type; accumulation is always done in float.
template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
struct StreamKGemmTypeConfiguration
{
    using ADataType   = ADataType_;
    using BDataType   = BDataType_;
    using AccDataType = float;
    using CDataType   = CDataType_;
};
|
||||
|
||||
// Builds the command-line parser for the Stream-K GEMM examples and parses
// argv. Returns a tuple {parse_succeeded, parser}; callers read the parsed
// options from the returned parser.
auto createArgs(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    // Each insert is (name, default value, help text).
    arg_parser.insert("m", "512", "m dimension")
        .insert("n", "512", "n dimension")
        .insert("k", "512", "k dimension")
        .insert("a_layout", "R", "A tensor data layout - Row by default")
        .insert("b_layout", "C", "B tensor data layout - Column by default")
        .insert("c_layout", "R", "C tensor data layout - Row by default")
        .insert("reduction_strategy",
                "atomic",
                "strategy for storing results in C tensor - atomic/linear")
        .insert("persistent_dp",
                "0",
                "0. Non-persistent data-parallel section, 1 Fully persistent kernel.")
        // Strides of 0 are later replaced by defaults derived from m/n/k.
        .insert("stride_a", "0", "Tensor A stride")
        .insert("stride_b", "0", "Tensor B stride")
        .insert("stride_c", "0", "Tensor C stride")
        .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
        .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
        .insert("warmup", "50", "number of iterations before benchmarking the kernel")
        .insert("repeat", "100", "number of iterations to benchmark the kernel")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert("init", "0", "0:random, 1:linear, 2:constant(1)")
        .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
|
||||
334
example/ck_tile/40_streamk_gemm/run_gemm_example.inc
Normal file
334
example/ck_tile/40_streamk_gemm/run_gemm_example.inc
Normal file
@@ -0,0 +1,334 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<
|
||||
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Computes the {relative, absolute} error tolerances used when comparing the
// device result against a reference. k_dim is the full reduction length,
// k_batch the number of workgroups cooperating on one C tile, and
// max_accumulated_value the largest value in the reference output.
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculateRtolAtol(const ck_tile::index_t k_dim,
                       const ck_tile::index_t k_batch,
                       const float max_accumulated_value)
{
    // The narrower of the two input types bounds the compute precision.
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
    // Calculate thresholds
    const auto relative_tolerance =
        ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
            ck_tile::integer_divide_ceil(k_dim, k_batch));
    const auto absolute_tolerance =
        ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
            max_accumulated_value / k_batch, ck_tile::integer_divide_ceil(k_dim, k_batch));
    // Calculate error due to multiple WGs working in the same C macro tile
    const auto relative_tolerance_split_k =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(k_batch);
    const auto absolute_tolerance_split_k =
        ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(max_accumulated_value,
                                                                         k_batch);
    // Use higher threshold
    return ck_tile::make_tuple(std::max(relative_tolerance, relative_tolerance_split_k),
                               std::max(absolute_tolerance, absolute_tolerance_split_k));
}
|
||||
|
||||
// Forward declaration of the kernel launcher: instantiates and runs the
// Stream-K kernel for one reduction strategy, returning {average kernel time
// in ms, number of workgroups per C tile}. The definition is provided by the
// translation unit that includes this .inc file (see streamk_gemm_basic.cpp).
template <typename GemmConfiguration,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccumulatorDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename CDEElementWise = ck_tile::element_wise::PassThrough,
          ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
                                         const ck_tile::stream_config& stream_config);
|
||||
|
||||
// Packs the problem description into StreamKHostArgs and dispatches to the
// gemm<...> instantiation matching the runtime-selected reduction strategy.
// Returns {average kernel time in ms, number of workgroups per C tile}.
template <typename GemmConfiguration,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccumulatorDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
std::tuple<float, ck_tile::index_t> invokeGemm(ck_tile::DeviceMem& a_m_k_device_memory,
                                               ck_tile::DeviceMem& b_k_n_device_memory,
                                               ck_tile::DeviceMem& c_m_n_device_memory,
                                               ck_tile::index_t m_dim,
                                               ck_tile::index_t n_dim,
                                               ck_tile::index_t k_dim,
                                               ck_tile::index_t stride_a,
                                               ck_tile::index_t stride_b,
                                               ck_tile::index_t stride_c,
                                               int warmup_iterations,
                                               int repeat_iterations,
                                               bool flush_cache,
                                               ck_tile::StreamKReductionStrategy reduction_strategy)
{
    ck_tile::StreamKHostArgs args{a_m_k_device_memory.GetDeviceBuffer(),
                                  b_k_n_device_memory.GetDeviceBuffer(),
                                  c_m_n_device_memory.GetDeviceBuffer(),
                                  m_dim,
                                  n_dim,
                                  k_dim,
                                  stride_a,
                                  stride_b,
                                  stride_c};

    std::tuple<float, ck_tile::index_t> average_time_and_batch;

    // The reduction strategy is a template parameter of gemm<>, so the
    // runtime value is mapped to one of two instantiations here.
    if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
    {
        average_time_and_batch = gemm<GemmConfiguration,
                                      ADataType,
                                      BDataType,
                                      DsDataType,
                                      AccumulatorDataType,
                                      CDataType,
                                      ALayout,
                                      BLayout,
                                      DsLayout,
                                      CLayout,
                                      CDEElementWise,
                                      ck_tile::StreamKReductionStrategy::Atomic>(
            args,
            ck_tile::stream_config{
                nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
    }
    else /*Reduction*/
    {
        average_time_and_batch = gemm<GemmConfiguration,
                                      ADataType,
                                      BDataType,
                                      DsDataType,
                                      AccumulatorDataType,
                                      CDataType,
                                      ALayout,
                                      BLayout,
                                      DsLayout,
                                      CLayout,
                                      CDEElementWise,
                                      ck_tile::StreamKReductionStrategy::Linear>(
            args,
            ck_tile::stream_config{
                nullptr, true, 1, warmup_iterations, repeat_iterations, true, flush_cache});
    }

    return average_time_and_batch;
}
|
||||
|
||||
// Compares the device result against a reference tensor using the given
// {relative, absolute} tolerances, prints the thresholds and the outcome, and
// returns true on success. 'variant' names the reference used in the log
// message ("CPU" or "GPU" at the call sites).
template <typename CDataType>
bool doVerify(const ck_tile::HostTensor<CDataType>& c_m_n_device_result,
              const ck_tile::HostTensor<CDataType>& c_m_n_reference,
              const ck_tile::tuple<double, double>& relative_absolute_tolerances,
              const char* variant)
{
    bool pass = ck_tile::check_err(c_m_n_device_result,
                                   c_m_n_reference,
                                   "Error: Incorrect results!",
                                   relative_absolute_tolerances.at(ck_tile::number<0>{}),
                                   relative_absolute_tolerances.at(ck_tile::number<1>{}));

    std::cout << "Relative error threshold: "
              << relative_absolute_tolerances.at(ck_tile::number<0>{})
              << " Absolute error threshold: "
              << relative_absolute_tolerances.at(ck_tile::number<1>{}) << std::endl;
    std::cout << "The " << variant << " verification result is:" << (pass ? "correct" : "fail")
              << std::endl;
    return pass;
}
|
||||
|
||||
// Translates the command-line reduction-strategy string ("atomic"/"linear")
// into the corresponding ck_tile enum value; throws on any other input.
ck_tile::StreamKReductionStrategy getReductionStrategyValue(const std::string& strategy)
{
    if(strategy == "atomic")
    {
        return ck_tile::StreamKReductionStrategy::Atomic;
    }
    if(strategy == "linear")
    {
        return ck_tile::StreamKReductionStrategy::Linear;
    }
    throw std::runtime_error("Unsupported Stream-K reduction strategy !!!");
}
|
||||
|
||||
// Runs the full Stream-K GEMM example for one concrete (ALayout, BLayout,
// CLayout) combination: parses options, initializes host/device tensors,
// launches the kernel via invokeGemm, prints performance numbers, and
// optionally verifies the result against a CPU or GPU reference.
// Returns the verification outcome (true/1 = pass), or -1 on a parse error.
// NOTE(review): with -v=0 no verification runs and the function returns 0,
// which callers treat as failure — confirm this is intended.
template <typename GemmConfiguration,
          typename TypeConfiguration,
          typename ALayout,
          typename BLayout,
          typename CLayout>
int runGemmExampleWithLayouts(int argc,
                              char* argv[],
                              const ALayout a_layout = ALayout{},
                              const BLayout b_layout = BLayout{},
                              [[maybe_unused]] const CLayout c_layout = CLayout{})
{
    auto [result, arg_parser] = createArgs(argc, argv);
    if(!result)
        return -1;

    // These configuration features are not implemented by this example.
    static_assert(!GemmConfiguration::PRESHUFFLE, "Not implemented");
    static_assert(!GemmConfiguration::USE_STRUCTURED_SPARSITY, "Not implemented");
    static_assert(!GemmConfiguration::PERMUTE_A, "Not implemented");
    static_assert(!GemmConfiguration::PERMUTE_B, "Not implemented");

    using ADataType           = typename TypeConfiguration::ADataType;
    using BDataType           = typename TypeConfiguration::BDataType;
    using AccumulatorDataType = typename TypeConfiguration::AccDataType;
    using CDataType           = typename TypeConfiguration::CDataType;

    ck_tile::index_t m_dim = arg_parser.get_int("m");
    ck_tile::index_t n_dim = arg_parser.get_int("n");
    ck_tile::index_t k_dim = arg_parser.get_int("k");

    ck_tile::index_t stride_a    = arg_parser.get_int("stride_a");
    ck_tile::index_t stride_b    = arg_parser.get_int("stride_b");
    ck_tile::index_t stride_c    = arg_parser.get_int("stride_c");
    int warmup_iterations        = arg_parser.get_int("warmup");
    int repeat_iterations        = arg_parser.get_int("repeat");
    ck_tile::index_t init_method = arg_parser.get_int("init");
    bool flush_cache             = arg_parser.get_bool("flush_cache");
    ck_tile::StreamKReductionStrategy reduction_strategy =
        getReductionStrategyValue(arg_parser.get_str("reduction_strategy"));

    // A stride of 0 from the command line is replaced with the layout's
    // default stride derived from the matrix dimensions.
    stride_a = ck_tile::get_default_stride(m_dim, k_dim, stride_a, is_row_major(a_layout));
    stride_b = ck_tile::get_default_stride(k_dim, n_dim, stride_b, is_row_major(b_layout));
    stride_c = ck_tile::get_default_stride(m_dim, n_dim, stride_c, is_row_major(CLayout{}));

    ck_tile::HostTensor<ADataType> a_m_k_host(
        ck_tile::host_tensor_descriptor(m_dim, k_dim, stride_a, is_row_major(a_layout)));
    ck_tile::HostTensor<BDataType> b_k_n_host(
        ck_tile::host_tensor_descriptor(k_dim, n_dim, stride_b, is_row_major(b_layout)));
    ck_tile::HostTensor<CDataType> c_m_n_device_result(
        ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));

    // Input initialization per the -init option:
    // 0 = uniform random, 1 = monotonic sequence, 2 = all ones, else zeros.
    if(init_method == 0)
    {
        ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k_host);
        ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n_host);
    }
    else if(init_method == 1)
    {
        ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k_host);
        ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n_host);
    }
    else if(init_method == 2)
    {
        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_host);
        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_host);
    }
    else
    {
        a_m_k_host.SetZero();
        b_k_n_host.SetZero();
    }

    ck_tile::DeviceMem a_m_k_device_memory(a_m_k_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_k_n_device_memory(b_k_n_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_m_n_device_memory(c_m_n_device_result.get_element_space_size_in_bytes());

    a_m_k_device_memory.ToDevice(a_m_k_host.data());
    b_k_n_device_memory.ToDevice(b_k_n_host.data());
    // Zero-initialize the output buffers before launching.
    c_m_n_device_memory.SetZero();
    c_m_n_device_result.SetZero();
    auto [average_time, num_wgs_per_tile] = invokeGemm<GemmConfiguration,
                                                       ADataType,
                                                       BDataType,
                                                       ck_tile::tuple<>,
                                                       AccumulatorDataType,
                                                       CDataType,
                                                       ALayout,
                                                       BLayout,
                                                       ck_tile::tuple<>,
                                                       CLayout>(a_m_k_device_memory,
                                                                b_k_n_device_memory,
                                                                c_m_n_device_memory,
                                                                m_dim,
                                                                n_dim,
                                                                k_dim,
                                                                stride_a,
                                                                stride_b,
                                                                stride_c,
                                                                warmup_iterations,
                                                                repeat_iterations,
                                                                flush_cache,
                                                                reduction_strategy);

    c_m_n_device_memory.FromDevice(c_m_n_device_result.data());

    // Performance accounting: 2*M*N*K flops; one read of A and B plus one
    // write of C. average_time is reported in milliseconds (see the print).
    std::size_t flop     = std::size_t(2) * m_dim * n_dim * k_dim;
    std::size_t num_byte = sizeof(ADataType) * m_dim * k_dim + sizeof(BDataType) * n_dim * k_dim +
                           sizeof(CDataType) * m_dim * n_dim;
    float tflops     = static_cast<float>(flop) / 1.E9 / average_time;
    float gb_per_sec = num_byte / 1.E6 / average_time;
    std::cout << "Run Gemm kernel with M=" << m_dim << " N=" << n_dim << " K=" << k_dim
              << " StrideA=" << stride_a << " StrideB=" << stride_b << " StrideC=" << stride_c
              << " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name
              << " C_Layout=" << CLayout::name
              << " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
              << " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
              << " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
              << " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
              << " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << average_time
              << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
    bool pass = false;

    // Memory on host to store gpu reference result
    ck_tile::HostTensor<CDataType> c_m_n_reference(
        ck_tile::host_tensor_descriptor(m_dim, n_dim, stride_c, is_row_major(CLayout{})));
    c_m_n_reference.SetZero();

    if(arg_parser.get_int("v") == 1) // Validate on the CPU
    {
        ck_tile::reference_gemm<ADataType, BDataType, AccumulatorDataType, CDataType>(
            a_m_k_host, b_k_n_host, c_m_n_reference);
        const float max_accumulated_value =
            *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
        // num_wgs_per_tile plays the role of k_batch when deriving tolerances.
        const auto relative_absolute_tolerances =
            calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
                k_dim, num_wgs_per_tile, max_accumulated_value);
        pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "CPU");
    }
    else if(arg_parser.get_int("v") == 2) // Validate on the GPU
    {
        // Memory on device to store gpu reference result
        ck_tile::DeviceMem c_m_n_gpu_buffer_reference(
            c_m_n_reference.get_element_space_size_in_bytes());
        c_m_n_gpu_buffer_reference.SetZero();

        ADataType* d_A = static_cast<ADataType*>(a_m_k_device_memory.GetDeviceBuffer());
        BDataType* d_B = static_cast<BDataType*>(b_k_n_device_memory.GetDeviceBuffer());
        CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buffer_reference.GetDeviceBuffer());
        ck_tile::reference_gemm_gpu<ADataType,
                                    BDataType,
                                    AccumulatorDataType,
                                    CDataType,
                                    ALayout,
                                    BLayout,
                                    CLayout>(
            d_A, d_B, d_C, m_dim, n_dim, k_dim, stride_a, stride_b, stride_c);
        c_m_n_gpu_buffer_reference.FromDevice(c_m_n_reference.data());

        const float max_accumulated_value =
            *std::max_element(c_m_n_reference.mData.begin(), c_m_n_reference.mData.end());
        const auto relative_absolute_tolerances =
            calculateRtolAtol<ADataType, BDataType, AccumulatorDataType, CDataType>(
                k_dim, num_wgs_per_tile, max_accumulated_value);
        pass = doVerify(c_m_n_device_result, c_m_n_reference, relative_absolute_tolerances, "GPU");
    }

    return pass;
}
|
||||
236
example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp
Normal file
236
example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp
Normal file
@@ -0,0 +1,236 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_utils.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
template <typename GemmConfiguration,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccumulatorDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& stream_config)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<GemmConfiguration::M_TILE,
|
||||
GemmConfiguration::N_TILE,
|
||||
GemmConfiguration::K_TILE>,
|
||||
ck_tile::sequence<GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::K_WARP>,
|
||||
ck_tile::sequence<GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE>,
|
||||
GemmConfiguration::PERMUTE_A,
|
||||
GemmConfiguration::PERMUTE_B>;
|
||||
|
||||
using TilePartitioner = ck_tile::
|
||||
StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfiguration::PERSISTENT>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfiguration::PAD_M,
|
||||
GemmConfiguration::PAD_N,
|
||||
GemmConfiguration::PAD_K,
|
||||
GemmConfiguration::DOUBLE_SMEM_BUFFER,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfiguration::TRANSPOSE_C,
|
||||
GemmConfiguration::USE_STRUCTURED_SPARSITY,
|
||||
GemmConfiguration::PERSISTENT,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS,
|
||||
GemmConfiguration::PRESHUFFLE>;
|
||||
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccumulatorDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfiguration::SCHEDULER>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccumulatorDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfiguration::M_WARP,
|
||||
GemmConfiguration::N_WARP,
|
||||
GemmConfiguration::M_WARP_TILE,
|
||||
GemmConfiguration::N_WARP_TILE,
|
||||
GemmConfiguration::K_WARP_TILE,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfiguration::NUM_WAVE_GROUPS>>;
|
||||
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kernel_args = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
workspace_data.SetZero();
|
||||
kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
|
||||
dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kernel_args))
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
}
|
||||
|
||||
auto reset_data_buffers = [&]() {
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
|
||||
}
|
||||
else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Linear)
|
||||
{
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
}
|
||||
};
|
||||
|
||||
std::function<void()> preprocess = reset_data_buffers;
|
||||
|
||||
float average_time =
|
||||
ck_tile::launch_kernel_time_mask(stream_config,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfiguration::BLOCK_PER_CU>(
|
||||
Kernel{}, grids, blocks, 0, kernel_args));
|
||||
|
||||
ck_tile::index_t num_wgs_per_tile = kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{average_time, num_wgs_per_tile};
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename GemmConfiguration, typename TypeConfiguration>
|
||||
int runGemmExamplePrecisionType(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return runGemmExampleWithLayouts<GemmConfiguration, TypeConfiguration>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported layouts.");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <template <typename PrecisionType, bool IsPersistent> typename GemmConfiguration>
|
||||
int runGemmExample(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = createArgs(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
auto persistent_data_parallel = arg_parser.get_bool("persistent_dp");
|
||||
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::bf16_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf16_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp16")
|
||||
{
|
||||
using TypeConfiguration = StreamKGemmTypeConfiguration<ck_tile::half_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::half_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfiguration =
|
||||
StreamKGemmTypeConfiguration<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::fp8_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfiguration =
|
||||
StreamKGemmTypeConfiguration<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
|
||||
if(persistent_data_parallel)
|
||||
{
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, true>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return runGemmExamplePrecisionType<GemmConfiguration<ck_tile::bf8_t, false>,
|
||||
TypeConfiguration>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !runGemmExample<GemmConfigurationMemoryInterwave>(argc, argv);
|
||||
}
|
||||
Reference in New Issue
Block a user