Introduces the new partitioner to implement the reduction StreamK kernel. (#3107)

* Introduces the new partitioner to implement the reduction StreamK kernel

* Add more doc text to functions

* Add persistent-dp option to streamk example

* Update example/ck_tile/40_streamk_gemm/README.md

[ROCm/composable_kernel commit: 5abe4109e0]
Author: Cong Ma
Date: 2025-11-04 10:32:17 -07:00 (committed by GitHub)
Parent: 1a8f824938
Commit: 0343c4e1fe
8 changed files with 298 additions and 75 deletions

View File

@@ -22,8 +22,8 @@ args:
-a_layout tensor A data layout (default: R)
-b_layout tensor B data layout (default: C)
-c_layout tensor C data layout (default: R)
-num_sk_blocks number of Stream-K blocks. -1: chosen by the algorithm; otherwise user-selected (default: -1)
-reduction_strategy strategy for storing results in the C tensor: atomic/reduction (default: atomic)
-persistent_dp persistence strategy for the data-parallel section: 0 for non-persistent, 1 for persistent (default: 0)
-stride_a tensor A stride (default:0)
-stride_b tensor B stride (default:0)
-stride_c tensor C stride (default:0)
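For example, passing -reduction_strategy reduction -persistent_dp 1 runs the Stream-K kernel with the reduction strategy and a fully persistent grid.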

View File

@@ -18,7 +18,6 @@ struct GemmConfigBase
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Persistent = false;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
@@ -27,12 +26,12 @@ struct GemmConfigBase
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
template <typename PrecType, bool Persistent_>
struct GemmConfigMemoryInterwave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 32;
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 16;
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
@@ -42,7 +41,8 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool Persistent = Persistent_;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
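// With Persistent now a template parameter, the host dispatch in
// run_gemm_example (below) instantiates GemmConfig<PrecType, true> or
// GemmConfig<PrecType, false> from the persistent_dp command-line flag.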
template <typename ADataType_, typename BDataType_ = ADataType_, typename CDataType_ = ADataType_>
@@ -96,12 +96,12 @@ auto create_args(int argc, char* argv[])
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("num_sk_blocks",
"-1",
"number of Stream-K blocks. -1: chosen by algorithm, or user selected")
.insert("reduction_strategy",
"atomic",
"strategy for storing results in C tensor - atomic/reduction")
.insert("persistent_dp",
"0",
"0. Non-persistent data-parallel section, 1 Fully persistent kernel.")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")

View File

@@ -69,20 +69,18 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup,
int n_repeat,
bool flush_cache,
ck_tile::StreamKReductionStrategy reduction_strategy,
uint32_t num_sk_blocks)
ck_tile::StreamKReductionStrategy reduction_strategy)
{
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C,
reduction_strategy,
num_sk_blocks};
ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C,
reduction_strategy};
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
@@ -197,7 +195,6 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::StreamKReductionStrategy reduction_strategy =
get_reduction_strategy_value(arg_parser.get_str("reduction_strategy"));
uint32_t num_sk_blocks = static_cast<uint32_t>(arg_parser.get_int("num_sk_blocks"));
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -261,8 +258,7 @@ int run_gemm_example_with_layouts(int argc,
n_warmup,
n_repeat,
flush_cache,
reduction_strategy,
num_sk_blocks);
reduction_strategy);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
@@ -279,10 +275,10 @@ int run_gemm_example_with_layouts(int argc,
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
<< " persistent_dp=" << arg_parser.get_str("persistent_dp") << " " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
bool pass = true;
bool pass = false;
// Memory on host to store gpu reference result
ck_tile::HostTensor<CDataType> c_m_n_ref(

View File

@@ -2,7 +2,6 @@
// SPDX-License-Identifier: MIT
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
#include "ck_tile/ops/common.hpp"
template <typename GemmConfig,
@@ -17,9 +16,8 @@ template <typename GemmConfig,
typename ELayout,
typename CDEElementWise,
ck_tile::StreamKReductionStrategy ReductionStrategy>
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
@@ -29,7 +27,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
using TilePartitioner = ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy>;
using TilePartitioner =
ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
@@ -78,9 +77,13 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
memory_operation.value,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
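// Note: the workspace is zero-initialized up front so that the per-CTA
// synchronization flags used by the reduction strategy start in the
// "not done" (0) state.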
dim3 grids = Kernel::GridSize(kargs.tile_partitioner);
dim3 blocks = Kernel::BlockSize();
@@ -101,28 +104,28 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
<< std::endl;
}
// Function to clear the output C tensor results after each repetition of the kernel
auto clear_gemm_output = [&]() {
auto reset_data_buffers = [&]() {
if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
// Clear the output C tensor results after each repetition of the kernel
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
}
else if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
{
// Reset sk flags to zero before each repetition of the kernel
workspace_data.SetZero();
}
};
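// Note: the Atomic strategy accumulates directly into C, so C must be cleared
// between repetitions, while the Reduction strategy instead resets the sk flags
// in the workspace; hence the more general name reset_data_buffers.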
std::function<void()> preprocess = clear_gemm_output;
std::function<void()> preprocess = reset_data_buffers;
float ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile<ReductionStrategy>(
kargs.tile_partitioner.sk_num_blocks,
// k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are
// big and each does one iteration. Thus, we ensure the value passed in is at least 1 to
// avoid division by zero errors.
ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u),
kargs.tile_partitioner.k_iters_per_tile.get());
ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
return std::tuple{ave_time, num_wgs_per_tile};
};
@@ -145,6 +148,8 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
}
}
#include "run_gemm_example.inc"
template <typename GemmConfig, typename TypeConfig>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
@@ -164,7 +169,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
return 0;
}
template <template <typename PreType> typename GemmConfig>
template <template <typename PreType, bool Persistent_> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -174,30 +179,63 @@ int run_gemm_example(int argc, char* argv[])
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
auto persistent_dp = arg_parser.get_bool("persistent_dp");
if(data_type == "bf16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf16_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp16")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else if(data_type == "fp8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else if(data_type == "bf8")
{
using TypeConfig = StreamKGemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>;
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>, TypeConfig>(
a_layout, b_layout, argc, argv);
if(persistent_dp)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, true>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
else
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t, false>, TypeConfig>(
a_layout, b_layout, argc, argv);
}
}
else
{

View File

@@ -110,6 +110,10 @@ CK_TILE_HOST double timing_loop_impl(TimerType timer,
{
for(int i = 0; i < s.cold_niters_; i++)
{
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
{
preprocess();
}
callables_func();
}
// Only profile preprocess if it's provided
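// [Illustrative aside] The preprocess now also runs before every cold/warmup
// iteration, so warmup launches start from freshly reset buffers just like the
// timed ones. A sketch of the timed part, assuming (names illustrative) that
// preprocess time is excluded from the measurement:
//
//   double total_ms = 0.0;
//   for(int i = 0; i < s.nrepeat_; ++i)
//   {
//       if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
//           preprocess();            // excluded from timing
//       timer.start(s.stream_id_);
//       callables_func();
//       timer.stop(s.stream_id_);
//       total_ms += timer.duration();
//   }
//   return total_ms / s.nrepeat_;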

View File

@@ -84,9 +84,10 @@ struct StreamKKernel
using CLayout = typename GemmPipeline::CLayout;
/// @brief Specify the data type configurations for A, B, and C
using ADataType = typename GemmPipeline::ADataType;
using BDataType = typename GemmPipeline::BDataType;
using CDataType = typename EpiloguePipeline::ODataType;
using ADataType = typename GemmPipeline::ADataType;
using BDataType = typename GemmPipeline::BDataType;
using CDataType = typename EpiloguePipeline::ODataType;
using AccDataType = typename EpiloguePipeline::AccDataType;
template <typename T>
static constexpr bool is_tuple_v = is_detected<is_tuple, T>::value;
@@ -243,14 +244,6 @@ struct StreamKKernel
CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
{
if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
}
return false;
}
return UniversalGemmKernel::IsSupportedArgument(kargs);
}
@@ -258,7 +251,7 @@ struct StreamKKernel
/// @return The buffer size needed.
CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
{
return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType));
}
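// [Illustrative aside] Sizing the workspace by AccDataType instead of CDataType
// means partial tiles are kept at accumulator precision (e.g. fp32 partials for
// fp16 output) until the owning CTA runs the epilogue.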
/// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
@@ -299,6 +292,118 @@ struct StreamKKernel
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
}
/// @brief Signals that the current thread block (CTA) has completed storing its partial
/// results.
/// @param kargs Kernel arguments, including the workspace pointer.
/// @param cta_idx The index of the current thread block (CTA).
/// @note This function utilizes a workgroup barrier to set a synchronization flag for the given
/// CTA index.
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_set(0, 1, cta_idx);
}
/// @brief Waits for the thread block (cta_idx) to complete storing its partial results.
/// @param kargs Kernel arguments, including the workspace pointer.
/// @param cta_idx The index of the thread block (CTA).
/// @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
/// set by the given CTA index.
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_eq(1, cta_idx);
}
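// [Illustrative aside] Together these form a producer/consumer handshake on
// per-CTA flags at the start of the workspace. A minimal sketch of the same
// protocol in portable C++ (the kernel uses workgroup_barrier instead):
//
//   #include <atomic>
//   #include <cstdint>
//
//   void signal_done(std::atomic<std::uint32_t>* flags, int cta_idx)
//   {
//       // Publish only after the partial tile store is device-visible.
//       flags[cta_idx].store(1u, std::memory_order_release);
//   }
//
//   void wait_done(const std::atomic<std::uint32_t>* flags, int cta_idx)
//   {
//       while(flags[cta_idx].load(std::memory_order_acquire) != 1u) {} // spin
//   }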
/// @brief Adds the values of a block tile to an output block tile.
/// @param in_out_block_tile The output block tile to which values are added.
/// @param in_block_tile The input block tile whose values are added.
/// @note This function iterates over the distributed spans of the block tiles and updates the
/// output block tile with accumulated values.
template <typename OAccTile>
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
const OAccTile& in_block_tile) const
{
using BlockType = remove_cvref_t<decltype(in_out_block_tile)>;
constexpr auto o_spans = BlockType::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto idx = make_tuple(idx0, idx1);
in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx];
});
});
}
/// @brief Loads a partial block tile from the workspace buffer.
/// @param kargs Kernel arguments, including the workspace pointer.
/// @param cta_idx The index of the thread block (CTA).
/// @param c_block_tile_dist The tile distribution for the block.
/// @return The loaded partial block tile.
/// @note This function calculates the buffer pointer and uses the tile distribution for loading
/// the partial block tile.
template <typename DataType, typename OAccTileDist>
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
index_t cta_idx,
const OAccTileDist& c_block_tile_dist) const
{
const auto c_block_tile_buffer_size =
TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType);
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
static_cast<DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
number<GemmPipeline::GetVectorSizeC()>{},
number<1>{});
auto partial_tile_window = make_tile_window(
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0},
c_block_tile_dist);
return load_tile(partial_tile_window);
}
/// @brief Stores a partial block tile to the workspace buffer.
/// @param kargs Kernel arguments, including the workspace pointer.
/// @param cta_idx The index of the thread block (CTA).
/// @param c_block_tile The block tile to be stored.
/// @note This function calculates the buffer pointer and uses the tile window for storing the
/// partial block tile.
template <typename OAccTile>
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
index_t cta_idx,
const OAccTile& c_block_tile) const
{
const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock *
TilePartitioner::NPerBlock *
sizeof(typename OAccTile::DataType);
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
number<GemmPipeline::GetVectorSizeC()>{},
number<1>{});
auto partial_tile_window = make_tile_window(
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0});
store_tile(partial_tile_window, c_block_tile);
}
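// [Illustrative aside] LoadPartial and StorePartial address the same workspace
// layout: the flags buffer first, then one MPerBlock x NPerBlock accumulator
// tile per Stream-K CTA. The offset arithmetic, in isolation:
//
//   std::size_t partial_tile_offset(std::size_t flags_buffer_bytes,
//                                   std::size_t m_per_block, std::size_t n_per_block,
//                                   std::size_t acc_element_bytes, std::size_t cta_idx)
//   {
//       return flags_buffer_bytes +
//              cta_idx * (m_per_block * n_per_block * acc_element_bytes);
//   }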
/// @brief Runs the main Stream-K algorithm.
/// @param kargs Stream-K kernel arguments.
/// @param cta_idx The current Stream-K workgroup's index.
@@ -347,7 +452,88 @@ struct StreamKKernel
}
else
{
// TODO: Apply reduction logic.
const auto c_macro_tile_idx =
kargs.tile_partitioner.get_output_tile_index(tile_idx);
index_t i_m =
c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
index_t i_n =
c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
UniversalGemmKernel::template MakeGemmTensorViews<
EpiloguePipeline::MemoryOperation>(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size);
const auto& gemm_pad_views =
UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n);
// Run GEMM cooperatively by whole workgroup.
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
// Since num_loop can vary per WG and per iteration of the Stream-K while loop,
// we compute has_hot_loop and tail_num here, following the same pattern as
// grouped GEMM. In this case, we call the GemmPipeline's operator() function
// that takes both has_hot_loop and tail_num.
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
bs_block_window[UniversalGemmKernel::I0],
num_loop_sk,
has_hot_loop,
tail_num,
smem_ptr_0);
auto tile_started = iter_start == tile_iter_start;
auto tile_ended = iter_end >= tile_iter_end;
if(!tile_started)
{
StorePartial(kargs, cta_idx, c_block_tile);
// Ensure device-wide visibility of partial results stored in global memory
// before signaling completion. __threadfence() guarantees that all global
// memory writes by this thread are visible to other threads on the device.
__threadfence(); // send signal when the store is done
SignalStorePartialDone(kargs, cta_idx);
}
else
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
{
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = cta_idx + 1;
while(accum_iters < iter_per_tile)
{
WaitStorePartialDone(kargs, next_cta);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
}
// Prepare for next Stream-K loop iteration.
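// [Illustrative aside] In short: the CTA that starts a tile owns its epilogue.
// Every other contributing CTA stores its partials and signals completion; the
// owner waits on each successor CTA in turn, accumulates its partials via
// AddBlockTile, and then writes the C tile exactly once.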

View File

@@ -31,21 +31,20 @@ struct StreamKTilePartitionerBase
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
private:
/**
* @brief Calculates the total space needed for the partials buffer.
*
* @param acc_element_bytes The number of bytes for the accumulator data type used in the GEMM.
* @return index_t The number of bytes needed for the partials buffer.
*/
CK_TILE_HOST index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept;
CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept;
/**
* @brief Calculates the total space needed for the flags buffer.
*
* @return index_t The number of bytes needed for the flags buffer.
*/
CK_TILE_HOST index_t get_flags_buffer_size() const noexcept;
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
public:
/**
@@ -123,7 +122,7 @@ struct StreamKTilePartitionerBase
* @param acc_element_bytes The number of bytes for the accumulator data type used in the GEMM.
* @return index_t The number of bytes needed for the partials and flags buffers.
*/
CK_TILE_HOST index_t get_workspace_size(index_t acc_element_bytes) const noexcept;
CK_TILE_HOST_DEVICE index_t get_workspace_size(index_t acc_element_bytes) const noexcept;
/**
* @brief Returns the number of macro tiles in the C tensor.

View File

@@ -45,7 +45,7 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTi
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST index_t
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_partials_buffer_size(
index_t acc_element_bytes) const noexcept
{
@@ -53,7 +53,7 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_parti
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST index_t
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags_buffer_size()
const noexcept
{
@@ -116,7 +116,7 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_outpu
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST index_t
CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
index_t acc_element_bytes) const noexcept
{
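// [Illustrative aside] Promoting these helpers to CK_TILE_HOST_DEVICE lets the
// kernel recompute workspace offsets on the device (as LoadPartial/StorePartial
// do via get_flags_buffer_size), while the host continues to use
// get_workspace_size when allocating the buffer.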