Implement RunGemmDesc that allows directly passing descriptors

This commit is contained in:
Matti Eskelinen
2025-12-16 12:12:52 +00:00
parent 700b2ec9c0
commit 96820bf5a8

View File

@@ -989,6 +989,78 @@ struct UniversalGemmKernel
}
}
/**
 * @brief Variant of RunGemm that takes ready-made tensor descriptors.
 *
 * Global-memory tensor views for A, B, D and E are built directly from the
 * descriptors passed in. The views are then padded, turned into per-block tile
 * windows, fed through GemmPipeline and finally through EpiloguePipeline.
 *
 * NOTE(review): the E view is created with make_tensor_view's default template
 * arguments apart from the address space. If the epilogue depends on a
 * non-default destination memory operation (e.g. accumulation across split-K
 * batches), nothing forwards it here - confirm against RunGemm.
 *
 * @tparam AGridDesc Descriptor type shared by all A tensors.
 * @tparam BGridDesc Descriptor type shared by all B tensors.
 * @tparam EGridDesc Descriptor type of the E tensor; also used as the element
 *                   type of @p ds_desc.
 * @tparam UseDefaultScheduler When false, the epilogue is executed only where
 *                             get_warp_id() returns 0.
 *
 * @param as_ptr  Pointers to the A tensors, cast per tensor to the matching
 *                element type of AsDataType.
 * @param bs_ptr  Pointers to the B tensors, cast per tensor to the matching
 *                element type of BsDataType.
 * @param ds_ptr  Untyped pointers to the D tensors, cast per tensor to the
 *                matching element type of DsDataType.
 * @param e_ptr   Pointer to the E tensor.
 * @param smem_ptr_0 Scratch pointer handed to both GemmPipeline and
 *                   EpiloguePipeline (presumably shared memory - confirm).
 * @param splitk_batch_offset Only its splitted_k member is read, to derive the
 *                            loop count.
 * @param block_idx_m Block index forwarded to MakeGemmTileWindows.
 * @param block_idx_n Block index forwarded to MakeGemmTileWindows.
 * @param as_desc Descriptors of the A tensors.
 * @param bs_desc Descriptors of the B tensors.
 * @param ds_desc Descriptors of the D tensors.
 * @param e_desc  Descriptor of the E tensor.
 */
template <typename AGridDesc,
          typename BGridDesc,
          typename EGridDesc,
          bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void RunGemmDesc(const std::array<const ADataType*, NumATensor>& as_ptr,
                                       const std::array<const BDataType*, NumBTensor>& bs_ptr,
                                       const std::array<const void*, NumDTensor>& ds_ptr,
                                       EDataType* e_ptr,
                                       void* smem_ptr_0,
                                       const SplitKBatchOffset& splitk_batch_offset,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n,
                                       const std::array<AGridDesc, NumATensor>& as_desc,
                                       const std::array<BGridDesc, NumBTensor>& bs_desc,
                                       const std::array<EGridDesc, NumDTensor>& ds_desc,
                                       const EGridDesc& e_desc)
{
    // One global-memory view per A tensor, typed by the matching AsDataType entry.
    const auto a_views = generate_tuple(
        [&](auto idx) {
            using ElemType = remove_cvref_t<std::tuple_element_t<idx.value, AsDataType>>;
            return make_tensor_view<address_space_enum::global>(
                static_cast<const ElemType*>(as_ptr[idx]), as_desc[idx]);
        },
        number<NumATensor>{});

    // Same construction for the B tensors.
    const auto b_views = generate_tuple(
        [&](auto idx) {
            using ElemType = remove_cvref_t<std::tuple_element_t<idx.value, BsDataType>>;
            return make_tensor_view<address_space_enum::global>(
                static_cast<const ElemType*>(bs_ptr[idx]), bs_desc[idx]);
        },
        number<NumBTensor>{});

    // D pointers arrive untyped; the element type comes from DsDataType.
    const auto d_views = generate_tuple(
        [&](auto idx) {
            using ElemType = remove_cvref_t<std::tuple_element_t<idx.value, DsDataType>>;
            return make_tensor_view<address_space_enum::global>(
                static_cast<const ElemType*>(ds_ptr[idx]), ds_desc[idx]);
        },
        number<NumDTensor>{});

    // Output view.
    auto e_view =
        make_tensor_view<address_space_enum::global>(static_cast<EDataType*>(e_ptr), e_desc);

    // Pack in the (A, B, D, E) order that the I0..I3 lookups below rely on.
    const auto all_views = make_tuple(a_views, b_views, d_views, e_view);

    const auto& padded_views = MakeGemmPadViews(all_views);
    auto tile_windows        = MakeGemmTileWindows(padded_views, block_idx_m, block_idx_n);

    const index_t k_loop_count =
        amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

    // Main loop: the GEMM pipeline consumes the A and B windows.
    const auto& a_windows = tile_windows.at(I0);
    const auto& b_windows = tile_windows.at(I1);
    const auto& d_windows = tile_windows.at(I2);

    const auto& acc_tile = GemmPipeline{}.template operator()(
        a_windows, AElementWise{}, b_windows, BElementWise{}, k_loop_count, smem_ptr_0);

    // With the default scheduler the epilogue always runs; otherwise only where
    // get_warp_id() is 0.
    const bool runs_epilogue = UseDefaultScheduler || (get_warp_id() == 0);
    if(runs_epilogue)
    {
        auto& e_window = tile_windows.at(I3);
        EpiloguePipeline{}(e_window, acc_tile, d_windows, smem_ptr_0);
    }
}
/**
* @brief Runs single GEMM problem cooperatively by whole workgroup.
*