mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Streamk functional tests (#2974)
* Add initial fp16_mem_128x128x32_2x2x1_32x32x16_NonPersistent test suite * Account for stride when computing K offsets for A and B tensor This change ensures that the correct stride is used when computing the K offsets into the A and B tensors in the Stream-K Kernel's operator() function. This ensures that the kernel executes correct regardless of whether A and B are row or column major. * Move helper code to test_gemm_streamk_util.hpp * Separate tests into smoke/regression/extended. Add bf16 datatype * Run clang-format * Refactor combinatorial macro expansion and naming * Adjust the initialization values to account for better tolerance on bf16 * Correct BF16 datatypes in comments * Move the extended tests under the REGRESSION_TESTS label * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Emily Martins <emily.martins@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
0843815db7
commit
f5708882a3
@@ -303,19 +303,20 @@ struct StreamKKernel
|
||||
auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx);
|
||||
|
||||
// Get the offsets in A, B, C tensors.
|
||||
index_t i_m = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I0] *
|
||||
index_t i_m = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I0] *
|
||||
TilePartitioner::MPerBlock);
|
||||
index_t i_n = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I1] *
|
||||
index_t i_n = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I1] *
|
||||
TilePartitioner::NPerBlock);
|
||||
index_t i_k = static_cast<index_t>(iter_offset) * TilePartitioner::KPerBlock;
|
||||
auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
|
||||
static_cast<index_t>(iter_offset), kargs.stride_As[0], kargs.stride_Bs[0]);
|
||||
|
||||
// Determine the total size along the K dimension the WG is using in this iteration
|
||||
// (used to construct tensor views).
|
||||
index_t k_size = static_cast<index_t>(current_iter_length * TilePartitioner::KPerBlock);
|
||||
|
||||
// Update pointer offsets for A, B, and C.
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k;
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k;
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// Run the GEMM pipeline and Epilogue.
|
||||
@@ -339,6 +340,41 @@ struct StreamKKernel
|
||||
}
|
||||
|
||||
private:
|
||||
/// @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
/// the starting macro tile index in the K dimension for the workgroup.
|
||||
/// @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
/// of A and B.
|
||||
/// @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
/// major.
|
||||
template <typename ALayout, typename BLayout>
|
||||
CK_TILE_DEVICE static tuple<index_t, index_t>
|
||||
GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
|
||||
{
|
||||
index_t stride_offset_a;
|
||||
index_t stride_offset_b;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
stride_offset_a = stride_a;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_offset_a = 1;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_offset_b = stride_b;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_offset_b = 1;
|
||||
}
|
||||
|
||||
index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
|
||||
|
||||
return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static int NumCU()
|
||||
{
|
||||
hipDeviceProp_t dev_prop;
|
||||
|
||||
Reference in New Issue
Block a user