mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Merge commit '5abe4109e0c30993b9e1afe00f95154939043859' into develop
This commit is contained in:
@@ -110,6 +110,10 @@ CK_TILE_HOST double timing_loop_impl(TimerType timer,
|
||||
{
|
||||
for(int i = 0; i < s.cold_niters_; i++)
|
||||
{
|
||||
if constexpr(!std::is_same_v<PreprocessFunc, std::nullptr_t>)
|
||||
{
|
||||
preprocess();
|
||||
}
|
||||
callables_func();
|
||||
}
|
||||
// Only profile preprocess if it's provided
|
||||
|
||||
@@ -84,9 +84,10 @@ struct StreamKKernel
|
||||
using CLayout = typename GemmPipeline::CLayout;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, and C
|
||||
using ADataType = typename GemmPipeline::ADataType;
|
||||
using BDataType = typename GemmPipeline::BDataType;
|
||||
using CDataType = typename EpiloguePipeline::ODataType;
|
||||
using ADataType = typename GemmPipeline::ADataType;
|
||||
using BDataType = typename GemmPipeline::BDataType;
|
||||
using CDataType = typename EpiloguePipeline::ODataType;
|
||||
using AccDataType = typename EpiloguePipeline::AccDataType;
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool is_tuple_v = is_detected<is_tuple, T>::value;
|
||||
@@ -243,14 +244,6 @@ struct StreamKKernel
|
||||
|
||||
/// @brief Checks whether the given kernel arguments can be run by this kernel.
/// @param kargs Stream-K kernel arguments to validate.
/// @return The result of UniversalGemmKernel::IsSupportedArgument(kargs).
/// @note No reduction strategy is rejected here: the non-atomic reduction path exchanges partial
/// tiles through the workspace (see StorePartial / LoadPartial), so rejecting it up front would
/// make that path unreachable.
CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
{
    return UniversalGemmKernel::IsSupportedArgument(kargs);
}
|
||||
|
||||
@@ -258,7 +251,7 @@ struct StreamKKernel
|
||||
/// @return The buffer size needed.
|
||||
CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
{
    // Partial tiles are written to the workspace in the accumulator element type, so the
    // partitioner is asked for a size based on AccDataType rather than the output type.
    return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType));
}
|
||||
|
||||
/// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
|
||||
@@ -299,6 +292,118 @@ struct StreamKKernel
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
|
||||
}
|
||||
|
||||
/// @brief Signals that the current thread block (CTA) has completed storing its partial
|
||||
/// results.
|
||||
/// @param kargs Kernel arguments, including the workspace pointer.
|
||||
/// @param cta_idx The index of the current thread block (CTA).
|
||||
/// @note This function utilizes a workgroup barrier to set a synchronization flag for the given
|
||||
/// CTA index.
|
||||
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
|
||||
index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_set(0, 1, cta_idx);
|
||||
}
|
||||
|
||||
/// @brief Waits for the thread block (cta_idx) to complete storing its partial results.
|
||||
/// @param kargs Kernel arguments, including the workspace pointer.
|
||||
/// @param cta_idx The index of the thread block (CTA).
|
||||
/// @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
|
||||
/// set by the given CTA index.
|
||||
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_eq(1, cta_idx);
|
||||
}
|
||||
|
||||
/// @brief Adds the values of a block tile to an output block tile.
|
||||
/// @param in_out_block_tile The output block tile to which values are added.
|
||||
/// @param in_block_tile The input block tile whose values are added.
|
||||
/// @note This function iterates over the distributed spans of the block tiles and updates the
|
||||
/// output block tile with accumulated values.
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
|
||||
const OAccTile& in_block_tile) const
|
||||
{
|
||||
using BlockType = remove_cvref_t<decltype(in_out_block_tile)>;
|
||||
constexpr auto o_spans = BlockType::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto idx = make_tuple(idx0, idx1);
|
||||
in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx];
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// @brief Loads a partial block tile from the workspace buffer.
|
||||
/// @param kargs Kernel arguments, including the workspace pointer.
|
||||
/// @param cta_idx The index of the thread block (CTA).
|
||||
/// @param c_block_tile_dist The tile distribution for the block.
|
||||
/// @return The loaded partial block tile.
|
||||
/// @note This function calculates the buffer pointer and uses the tile distribution for loading
|
||||
/// the partial block tile.
|
||||
template <typename DataType, typename OAccTileDist>
|
||||
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
|
||||
index_t cta_idx,
|
||||
const OAccTileDist& c_block_tile_dist) const
|
||||
{
|
||||
const auto c_block_tile_buffer_size =
|
||||
TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType);
|
||||
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
|
||||
kargs.tile_partitioner.get_flags_buffer_size() +
|
||||
cta_idx * c_block_tile_buffer_size;
|
||||
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<DataType*>(partial_buffer_ptr),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(TilePartitioner::NPerBlock, 1),
|
||||
number<GemmPipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
|
||||
auto partial_tile_window = make_tile_window(
|
||||
partial_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0},
|
||||
c_block_tile_dist);
|
||||
|
||||
return load_tile(partial_tile_window);
|
||||
}
|
||||
|
||||
/// @brief Stores a partial block tile to the workspace buffer.
|
||||
/// @param kargs Kernel arguments, including the workspace pointer.
|
||||
/// @param cta_idx The index of the thread block (CTA).
|
||||
/// @param c_block_tile The block tile to be stored.
|
||||
/// @note This function calculates the buffer pointer and uses the tile window for storing the
|
||||
/// partial block tile.
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
|
||||
index_t cta_idx,
|
||||
const OAccTile& c_block_tile) const
|
||||
{
|
||||
const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock *
|
||||
TilePartitioner::NPerBlock *
|
||||
sizeof(typename OAccTile::DataType);
|
||||
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
|
||||
kargs.tile_partitioner.get_flags_buffer_size() +
|
||||
cta_idx * c_block_tile_buffer_size;
|
||||
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(TilePartitioner::NPerBlock, 1),
|
||||
number<GemmPipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
|
||||
auto partial_tile_window = make_tile_window(
|
||||
partial_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0});
|
||||
|
||||
store_tile(partial_tile_window, c_block_tile);
|
||||
}
|
||||
|
||||
/// @brief Runs the main Stream-K algorithm.
|
||||
/// @param kargs Stream-K kernel arguments.
|
||||
/// @param cta_idx The current Stream-K workgroup's index.
|
||||
@@ -347,7 +452,88 @@ struct StreamKKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: Apply reduction logic.
|
||||
const auto c_macro_tile_idx =
|
||||
kargs.tile_partitioner.get_output_tile_index(tile_idx);
|
||||
index_t i_m =
|
||||
c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
|
||||
index_t i_n =
|
||||
c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
|
||||
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<
|
||||
EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views =
|
||||
UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop,
|
||||
// we compute has_hot_loop and tail_num here. This is a similar pattern used by
|
||||
// grouped GEMM. In this case, we call the GemmPipeline's operator() function
|
||||
// that takes both has_hot_loop and tail_num.
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop_sk,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
|
||||
auto tile_started = iter_start == tile_iter_start;
|
||||
auto tile_ended = iter_end >= tile_iter_end;
|
||||
if(!tile_started)
|
||||
{
|
||||
StorePartial(kargs, cta_idx, c_block_tile);
|
||||
// Ensure device-wide visibility of partial results stored in global memory
|
||||
// before signaling completion. __threadfence() guarantees that all global
|
||||
// memory writes by this thread are visible to other threads on the device.
|
||||
__threadfence(); // send signal when the store is done
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
{
|
||||
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = cta_idx + 1;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, next_cta);
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(
|
||||
accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare for next Stream-K loop iteration.
|
||||
|
||||
@@ -31,21 +31,20 @@ struct StreamKTilePartitionerBase
|
||||
|
||||
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Calculates the total space needed for the partials buffer.
|
||||
*
|
||||
* @param acc_element_bytes The number of bytes for the accumulator data type used in the GEMM.
|
||||
* @return index_t The number of bytes needed for the partials buffer.
|
||||
*/
|
||||
CK_TILE_HOST index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept;
|
||||
CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the total space needed for the flags buffer.
|
||||
*
|
||||
* @return index_t The number of bytes needed for the flags buffer.
|
||||
*/
|
||||
CK_TILE_HOST index_t get_flags_buffer_size() const noexcept;
|
||||
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
|
||||
|
||||
public:
|
||||
/**
|
||||
@@ -123,7 +122,7 @@ struct StreamKTilePartitionerBase
|
||||
* @param acc_element_bytes The number of bytes for the accumulator data type used in the GEMM.
|
||||
* @return index_t The number of bytes needed for the partials and flags buffers.
|
||||
*/
|
||||
CK_TILE_HOST index_t get_workspace_size(index_t acc_element_bytes) const noexcept;
|
||||
CK_TILE_HOST_DEVICE index_t get_workspace_size(index_t acc_element_bytes) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the number of macro tiles in the C tensor.
|
||||
|
||||
@@ -45,7 +45,7 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTi
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_partials_buffer_size(
|
||||
index_t acc_element_bytes) const noexcept
|
||||
{
|
||||
@@ -53,7 +53,7 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_parti
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags_buffer_size()
|
||||
const noexcept
|
||||
{
|
||||
@@ -116,7 +116,7 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_outpu
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
|
||||
index_t acc_element_bytes) const noexcept
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user