mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
Introduces the new partitioner to implement the reduction StreamK kernel. (#3107)
* Introduces the new partitioner to implement the reduction StreamK kernel
* Add more doc text to functions
* Add persistent-dp option to streamk example
* Update example/ck_tile/40_streamk_gemm/README.md
[ROCm/composable_kernel commit: 5abe4109e0]
This commit is contained in:
@@ -22,8 +22,8 @@ args:
|
||||
-a_layout tensor A data layout (default: R)
|
||||
-b_layout tensor B data layout (default: C)
|
||||
-c_layout tensor C data layout (default: R)
|
||||
-num_sk_blocks number of Stream-K blocks. -1: chosen by algorithm, or user selected (default:-1)
|
||||
-reduction_strategy strategy for storing results in C tensor. atomic/reduction (default:atomic)
|
||||
-persistent_dp persistent strategy for data-parallel section. Set to 0 for non-persistent or to 1 for persistent. (default:0)
|
||||
-stride_a tensor A stride (default:0)
|
||||
-stride_b tensor B stride (default:0)
|
||||
-stride_c tensor C stride (default:0)
|
||||
|
||||
@@ -18,7 +18,6 @@ struct GemmConfigBase
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool Persistent = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
@@ -27,12 +26,12 @@ struct GemmConfigBase
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
template <typename PrecType, bool Persistent_>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 16;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
@@ -42,7 +41,8 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr bool Persistent = Persistent_;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
};
|
||||
|
||||
template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
|
||||
@@ -96,12 +96,12 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Column by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("num_sk_blocks",
|
||||
"-1",
|
||||
"number of Stream-K blocks. -1: chosen by algorithm, or user selected")
|
||||
.insert("reduction_strategy",
|
||||
"atomic",
|
||||
"strategy for storing results in C tensor - atomic/reduction")
|
||||
.insert("persistent_dp",
|
||||
"0",
|
||||
"0. Non-persistent data-parallel section, 1 Fully persistent kernel.")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
|
||||
@@ -69,20 +69,18 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy,
|
||||
uint32_t num_sk_blocks)
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy)
|
||||
{
|
||||
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
reduction_strategy,
|
||||
num_sk_blocks};
|
||||
ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
reduction_strategy};
|
||||
|
||||
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
|
||||
|
||||
@@ -197,7 +195,6 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy =
|
||||
get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
|
||||
uint32_t num_sk_blocks = static_cast<uint32_t>(arg_parser.get_int("num_sk_blocks"));
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
@@ -261,8 +258,7 @@ int run_gemm_example_with_layouts(int argc,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
flush_cache,
|
||||
reduction_strategy,
|
||||
num_sk_blocks);
|
||||
reduction_strategy);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
@@ -279,10 +275,10 @@ int run_gemm_example_with_layouts(int argc,
|
||||
<< " B_Type=" << DataTypeTraits<BDataType>::name
|
||||
<< " C_Type=" << DataTypeTraits<CDataType>::name
|
||||
<< " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
|
||||
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
|
||||
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
bool pass = false;
|
||||
|
||||
// Memory on host to store gpu reference result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_ref(
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
@@ -17,9 +16,8 @@ template <typename GemmConfig,
|
||||
typename ELayout,
|
||||
typename CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
@@ -29,7 +27,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy>;
|
||||
using TilePartitioner =
|
||||
ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
@@ -78,9 +77,13 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
memory_operation.value,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
workspace_data.SetZero();
|
||||
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
|
||||
dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
@@ -101,28 +104,28 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Function to clear the output C tensor results after each repetition of the kernel
|
||||
auto clear_gemm_output = [&]() {
|
||||
auto reset_data_buffers = [&]() {
|
||||
if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Clear the output C tensor results after each repetition of the kernel
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
}
|
||||
else if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
// Reset sk flags to zero before each repetition of the kernel
|
||||
workspace_data.SetZero();
|
||||
}
|
||||
};
|
||||
|
||||
std::function<void()> preprocess = clear_gemm_output;
|
||||
std::function<void()> preprocess = reset_data_buffers;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
|
||||
kargs.tile_partitioner.sk_num_blocks,
|
||||
// k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
|
||||
// big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
|
||||
// avoid division by zero errors.
|
||||
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
|
||||
kargs.tile_partitioner.k_iters_per_tile.get());
|
||||
|
||||
ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
return std::tuple{ave_time, num_wgs_per_tile};
|
||||
};
|
||||
|
||||
@@ -145,6 +148,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
template <typename GemmConfig, typename TypeConfig>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
@@ -164,7 +169,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
template <template <typename PreType, bool Persistent_> typename GemmConfig>
|
||||
int run_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -174,30 +179,63 @@ int run_gemm_example(int argc, char* argv[])
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
auto persistent_dp = arg_parser.get_bool("persistent_dp");
|
||||
|
||||
if(data_type == "bf16")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
if(persistent_dp)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp16")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
if(persistent_dp)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
if(persistent_dp)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
if(persistent_dp)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, true>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, false>, TypeConfig>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user