mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Merge commit '2723dbd33245b76bfe716c5adc8c9fb577a4b68f' into develop
This commit is contained in:
@@ -887,6 +887,58 @@ struct tile_window_with_static_lengths
|
||||
this->window_lengths_ = window_lengths;
|
||||
this->bottom_tensor_view_ = bottom_tensor_view;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Print tile window elements for debugging.
|
||||
*
|
||||
* @tparam DataType Element data type (e.g., fp16_t, float, bf8_t)
|
||||
* @param start_i Starting row (inclusive)
|
||||
* @param end_i Ending row (exclusive)
|
||||
* @param start_j Starting column (inclusive)
|
||||
* @param end_j Ending column (exclusive)
|
||||
* @param label Optional output label
|
||||
*
|
||||
* @note Tested on fp16. Custom types may need adjustments.
|
||||
* @example tile_window.template print_tile_window_range<fp16_t>(0, 4, 0, 8, "A");
|
||||
*/
|
||||
template <typename DataType>
|
||||
CK_TILE_DEVICE void print_tile_window_range(index_t start_i,
|
||||
index_t end_i,
|
||||
index_t start_j,
|
||||
index_t end_j,
|
||||
const char* label = "") const
|
||||
{
|
||||
const auto& tensor_view = this->get_bottom_tensor_view();
|
||||
const auto window_origin = this->get_window_origin();
|
||||
|
||||
printf("%s Window Range [%d:%d, %d:%d] (origin: %d, %d):\n",
|
||||
label,
|
||||
start_i,
|
||||
end_i - 1,
|
||||
start_j,
|
||||
end_j - 1,
|
||||
window_origin[0],
|
||||
window_origin[1]);
|
||||
|
||||
for(index_t i = start_i; i < end_i; i++)
|
||||
{
|
||||
for(index_t j = start_j; j < end_j; j++)
|
||||
{
|
||||
// Create coordinate for this element relative to window origin
|
||||
auto coord =
|
||||
make_tensor_coordinate(tensor_view.get_tensor_descriptor(),
|
||||
make_tuple(window_origin[0] + i, window_origin[1] + j));
|
||||
|
||||
// Get the element using thread buffer type directly
|
||||
using ThreadBuf = thread_buffer<DataType, 2>;
|
||||
auto buf = tensor_view.template get_vectorized_elements<ThreadBuf>(coord, 0);
|
||||
auto value = buf.at(number<0>{}); // Extract first element from thread buffer
|
||||
printf(" %s[%d,%d] = %f", label, i, j, static_cast<float>(value));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TensorView_, typename WindowLengths_>
|
||||
|
||||
@@ -646,16 +646,13 @@ struct StreamKTilePartitioner
|
||||
* @brief Get length of loop iterations for stream-k loop
|
||||
*/
|
||||
CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start,
|
||||
uint32_t iter_end,
|
||||
uint32_t total_iter_length) const noexcept
|
||||
uint32_t iter_end) const noexcept
|
||||
{
|
||||
uint32_t iter_length_mod, iter_length_quo /*unused*/;
|
||||
k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod);
|
||||
uint32_t total_iter_length_val = static_cast<uint32_t>(total_iter_length);
|
||||
uint32_t current_iter_length =
|
||||
min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod,
|
||||
total_iter_length_val);
|
||||
return current_iter_length;
|
||||
// A WG's iter_end is either in the current C macro tile or not.
|
||||
// If it is not, then the macro tile boundary is where the WG must stop.
|
||||
uint32_t distance_to_tile_boundary =
|
||||
k_iters_per_tile.get() - (iter_start % k_iters_per_tile.get());
|
||||
return min(iter_start + distance_to_tile_boundary, iter_end) - iter_start;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -672,9 +669,7 @@ struct StreamKTilePartitioner
|
||||
CK_TILE_DEVICE void
|
||||
GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept
|
||||
{
|
||||
uint32_t tile_idx_val = static_cast<uint32_t>(tile_idx);
|
||||
uint32_t iter_offset_val = static_cast<uint32_t>(iter_offset);
|
||||
k_iters_per_tile.divmod(iter, tile_idx_val, iter_offset_val);
|
||||
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -374,7 +374,7 @@ struct GroupedGemmKernel
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
@@ -436,7 +436,7 @@ struct GroupedGemmKernel
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
|
||||
@@ -141,11 +141,17 @@ struct StreamKKernel
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args)
|
||||
/// @brief Constructs kernel arguments for the Stream-K kernel.
|
||||
/// @param host_args Stream-K host arguments.
|
||||
/// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
|
||||
/// The caller may select their own to assist with test reproducibility, etc.
|
||||
/// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may
|
||||
/// select their own to assist with test reproducibility, etc.
|
||||
/// @return The kernel arguments for Stream-K.
|
||||
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
|
||||
int num_cu = NumCU(),
|
||||
int occupancy = Occupancy())
|
||||
{
|
||||
uint32_t occupancy = static_cast<uint32_t>(Occupancy());
|
||||
uint32_t num_cu = static_cast<uint32_t>(NumCU());
|
||||
|
||||
return StreamKKernelArgs{{host_args.as_ptr,
|
||||
host_args.bs_ptr,
|
||||
host_args.ds_ptr,
|
||||
@@ -166,14 +172,71 @@ struct StreamKKernel
|
||||
TilePartitioner{static_cast<uint32_t>(host_args.M),
|
||||
static_cast<uint32_t>(host_args.N),
|
||||
static_cast<uint32_t>(host_args.K),
|
||||
num_cu,
|
||||
occupancy,
|
||||
static_cast<uint32_t>(num_cu),
|
||||
static_cast<uint32_t>(occupancy),
|
||||
host_args.num_sk_blocks}};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool
|
||||
IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs)
|
||||
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const typename UniversalGemmKernel::KernelArgs& kargs,
|
||||
const index_t num_loop,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const index_t k_size)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
|
||||
// has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
|
||||
// case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
|
||||
// tail_num.
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Checks whether the given Stream-K kernel arguments are supported.
/// @param kargs Stream-K kernel arguments to validate.
/// @return false when the reduction strategy is StreamKReductionStrategy::Reduction (an error
/// is logged if CK_TILE_LOGGING is enabled); otherwise the result of
/// UniversalGemmKernel::IsSupportedArgument.
CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
{
    const bool uses_reduction =
        (kargs.reduction_strategy == StreamKReductionStrategy::Reduction);

    // Anything other than the Reduction strategy is validated by the universal GEMM kernel.
    if(!uses_reduction)
    {
        return UniversalGemmKernel::IsSupportedArgument(kargs);
    }

    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
    {
        CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
    }
    return false;
}
|
||||
|
||||
@@ -199,9 +262,81 @@ struct StreamKKernel
|
||||
kargs.workspace_ptr = workspace_ptr;
|
||||
}
|
||||
|
||||
// Temporary placeholder to support the Occupancy() static function.
|
||||
// Since the Occupancy function uses kentry, this class must have an operator() function
|
||||
CK_TILE_DEVICE void operator()(StreamKKernelArgs /*kargs*/) const {}
|
||||
/// @brief Entry point for the Stream-K Kernel, performing the main Stream-K loop.
|
||||
CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
|
||||
{
|
||||
// Allocate LDS
|
||||
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
|
||||
|
||||
uint32_t block_idx = ck_tile::get_block_1d_id();
|
||||
|
||||
bool is_padding_block =
|
||||
__builtin_amdgcn_readfirstlane(block_idx >= kargs.tile_partitioner.sk_num_blocks &&
|
||||
block_idx < kargs.tile_partitioner.dp_start_block_idx);
|
||||
|
||||
// Padding blocks make it such that the DP blocks are aligned with the number of CUs; they
|
||||
// should not partake in the GEMM
|
||||
if(is_padding_block)
|
||||
return;
|
||||
|
||||
// Determine the K offset of the first and final macro tile in the A and B tensors along the
|
||||
// K dimension.
|
||||
uint32_t iter_start, iter_end;
|
||||
kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end);
|
||||
|
||||
// Main Stream-K loop
|
||||
while(true)
|
||||
{
|
||||
// Determine the number of macro tiles in A and B this WG is resposible for in the
|
||||
// current C macro tile.
|
||||
uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
|
||||
kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end));
|
||||
|
||||
// Determine the 1D tile_idx and the iter_offset for this WG.
|
||||
// The tile_idx is the 1D macro tile index in the C tensor.
|
||||
// The iter_offset is the starting macro tile index in the K dimension for the WG in the
|
||||
// current iteration of the while loop.
|
||||
uint32_t tile_idx, iter_offset;
|
||||
kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset);
|
||||
|
||||
// Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx)
|
||||
auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx);
|
||||
|
||||
// Get the offsets in A, B, C tensors.
|
||||
index_t i_m = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I0] *
|
||||
TilePartitioner::MPerBlock);
|
||||
index_t i_n = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I1] *
|
||||
TilePartitioner::NPerBlock);
|
||||
index_t i_k = static_cast<index_t>(iter_offset) * TilePartitioner::KPerBlock;
|
||||
|
||||
// Determine the total size along the K dimension the WG is using in this iteration
|
||||
// (used to construct tensor views).
|
||||
index_t k_size = static_cast<index_t>(current_iter_length * TilePartitioner::KPerBlock);
|
||||
|
||||
// Update pointer offsets for A, B, and C.
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k;
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// Run the GEMM pipeline and Epilogue.
|
||||
RunGemm({a_ptr},
|
||||
{b_ptr},
|
||||
{/*ds_ptr*/},
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
current_iter_length,
|
||||
i_m,
|
||||
i_n,
|
||||
k_size);
|
||||
|
||||
// Prepare for next Stream-K loop iteration.
|
||||
iter_start += current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
CK_TILE_HOST static int NumCU()
|
||||
|
||||
@@ -579,7 +579,7 @@ struct UniversalGemmKernel
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
const index_t k_size)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
|
||||
@@ -591,7 +591,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const AiDataType*>(as_ptr[i]),
|
||||
make_tuple(kargs.M, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.M, k_size),
|
||||
make_tuple(kargs.stride_As[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
@@ -600,7 +600,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const AiDataType*>(as_ptr[i]),
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.M),
|
||||
make_tuple(k_size, kargs.M),
|
||||
make_tuple(kargs.stride_As[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
@@ -617,7 +617,7 @@ struct UniversalGemmKernel
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB =
|
||||
std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
@@ -638,7 +638,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bs_ptr[i],
|
||||
make_tuple(splitk_batch_offset.splitted_k, kargs.N),
|
||||
make_tuple(k_size, kargs.N),
|
||||
make_tuple(kargs.stride_Bs[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
@@ -649,7 +649,7 @@ struct UniversalGemmKernel
|
||||
if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
|
||||
{
|
||||
constexpr index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const index_t K0 = splitk_batch_offset.splitted_k / K1;
|
||||
const index_t K0 = k_size / K1;
|
||||
constexpr index_t VectorSizeB =
|
||||
std::min(K1, GemmPipeline::GetVectorSizeB());
|
||||
const auto b_k0_n_k1_desc =
|
||||
@@ -672,7 +672,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
index_t kFlatK =
|
||||
GemmPipeline::BlockGemmShape::flatKPerWarp *
|
||||
(splitk_batch_offset.splitted_k /
|
||||
(k_size /
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
@@ -687,7 +687,7 @@ struct UniversalGemmKernel
|
||||
{
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
bs_ptr[i],
|
||||
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
|
||||
make_tuple(kargs.N, k_size),
|
||||
make_tuple(kargs.stride_Bs[i], 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
@@ -962,7 +962,7 @@ struct UniversalGemmKernel
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
@@ -1018,7 +1018,7 @@ struct UniversalGemmKernel
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
|
||||
as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
Reference in New Issue
Block a user