mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Implement RunGemmDesc that allows directly passing descriptors
This commit is contained in:
@@ -989,6 +989,78 @@ struct UniversalGemmKernel
|
||||
}
|
||||
}
|
||||
|
||||
// Version of RunGemm using descriptors
|
||||
template <typename AGridDesc,
|
||||
typename BGridDesc,
|
||||
typename EGridDesc,
|
||||
bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void RunGemmDesc(const std::array<const ADataType*, NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr_0,
|
||||
const SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const std::array<AGridDesc, NumATensor>& as_desc,
|
||||
const std::array<BGridDesc, NumBTensor>& bs_desc,
|
||||
const std::array<EGridDesc, NumDTensor>& ds_desc,
|
||||
const EGridDesc& e_desc)
|
||||
{
|
||||
// Create tensor views from descriptors (supports arbitrary stride patterns)
|
||||
const auto& as_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using AiDataType = remove_cvref_t<std::tuple_element_t<i.value, AsDataType>>;
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const AiDataType*>(as_ptr[i]), as_desc[i]);
|
||||
},
|
||||
number<NumATensor>{});
|
||||
|
||||
const auto& bs_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using BiDataType = remove_cvref_t<std::tuple_element_t<i.value, BsDataType>>;
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BiDataType*>(bs_ptr[i]), bs_desc[i]);
|
||||
},
|
||||
number<NumBTensor>{});
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DiDataType = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const DiDataType*>(ds_ptr[i]), ds_desc[i]);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
auto e_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(static_cast<EDataType*>(e_ptr), e_desc);
|
||||
|
||||
const auto& gemm_tensors_views_tuple =
|
||||
make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensors_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop =
|
||||
amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
|
||||
Reference in New Issue
Block a user