Wmma support for multiple ABD GEMM (#2803)

* multi_abd wmma support:

 - Add multiple A and B support to multiple D implementation (gridwise level)
 - Add multi_abd GEMM (device level)
 - Add instances (xdl parity)
 - Add tests (both xdl and wmma)
 - Add examples
 - Add ckProfiler support (both xdl and wmma)

* Fix bug in device print function

* Fix unused template parameter

* Fix batched gemm for multiABD gridwise implementation

* Fix gemm_universal_reduce with multiABDs gridwise implementation

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Enrico Degregori
2025-09-23 03:49:06 +02:00
committed by GitHub
parent de47ae2fdf
commit 3d29bff2f0
38 changed files with 5343 additions and 312 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -15,6 +15,151 @@
namespace ck {
namespace utility {
template <typename Argument, typename AsDataType, typename BsDataType, typename DsDataType>
struct RotatingMemWrapperMultiABD
{
static constexpr index_t NumAs = AsDataType::Size();
static constexpr index_t NumBs = BsDataType::Size();
static constexpr index_t NumDs = DsDataType::Size();
using AsGridPointer = decltype(Argument::p_as_grid);
using BsGridPointer = decltype(Argument::p_bs_grid);
using DsGridPointer = decltype(Argument::p_ds_grid);
RotatingMemWrapperMultiABD() = delete;
RotatingMemWrapperMultiABD(Argument& arg_,
std::size_t rotating_count_,
std::array<std::size_t, NumAs> size_as_,
std::array<std::size_t, NumBs> size_bs_,
std::array<std::size_t, NumDs> size_ds_)
: arg(arg_),
rotating_count(rotating_count_),
size_as(size_as_),
size_bs(size_bs_),
size_ds(size_ds_)
{
p_as_grids.push_back(arg.p_as_grid);
p_bs_grids.push_back(arg.p_bs_grid);
p_ds_grids.push_back(arg.p_ds_grid);
for(size_t i = 1; i < rotating_count; i++)
{
{
AsGridPointer as_buffer;
static_for<0, NumAs, 1>{}([&](auto j) {
void* pADeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_as_[j]));
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
static_cast<const void*>(p_as_grids[0][j]),
size_as_[j],
hipMemcpyDeviceToDevice));
using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
as_buffer(j) = static_cast<const ADataType*>(pADeviceBuf);
});
p_as_grids.push_back(as_buffer);
}
{
BsGridPointer bs_buffer;
static_for<0, NumBs, 1>{}([&](auto j) {
void* pBDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_bs_[j]));
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
static_cast<const void*>(p_bs_grids[0][j]),
size_bs_[j],
hipMemcpyDeviceToDevice));
using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
bs_buffer(j) = static_cast<const BDataType*>(pBDeviceBuf);
});
p_bs_grids.push_back(bs_buffer);
}
{
DsGridPointer ds_buffer;
static_for<0, NumDs, 1>{}([&](auto j) {
void* pDDeviceBuf;
hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
static_cast<const void*>(p_ds_grids[0][j]),
size_ds_[j],
hipMemcpyDeviceToDevice));
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
});
p_ds_grids.push_back(ds_buffer);
}
}
}
void Next()
{
if(rotating_count > 1)
{
std::size_t idx = iter++ % rotating_count;
arg.p_as_grid = p_as_grids[idx];
arg.p_bs_grid = p_bs_grids[idx];
arg.p_ds_grid = p_ds_grids[idx];
}
}
void Print()
{
std::cout << "RotatingMemWrapperMultiD: { size_a: {";
static_for<0, NumAs, 1>{}(
[&](auto j) { std::cout << size_as[j] << (j.value < NumAs - 1 ? ", " : ""); });
std::cout << "}, size_b: {";
static_for<0, NumBs, 1>{}(
[&](auto j) { std::cout << size_bs[j] << (j.value < NumBs - 1 ? ", " : ""); });
std::cout << "}, rotating_count: " << rotating_count << "}" << std::endl;
}
~RotatingMemWrapperMultiABD()
{
if(rotating_count > 1)
{
// restore ptr
arg.p_as_grid = p_as_grids[0];
arg.p_bs_grid = p_bs_grids[0];
arg.p_ds_grid = p_ds_grids[0];
// free device mem
for(size_t i = 1; i < rotating_count; i++)
{
static_for<0, NumAs, 1>{}([&](auto j) {
using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
hip_check_error(
hipFree(static_cast<void*>(const_cast<ADataType*>(p_as_grids[i][j]))));
});
static_for<0, NumBs, 1>{}([&](auto j) {
using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
hip_check_error(
hipFree(static_cast<void*>(const_cast<BDataType*>(p_bs_grids[i][j]))));
});
static_for<0, NumDs, 1>{}([&](auto j) {
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
hip_check_error(
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
});
}
}
}
private:
Argument& arg;
std::size_t iter = 0;
std::size_t rotating_count = 1;
std::array<std::size_t, NumAs> size_as = {0};
std::array<std::size_t, NumBs> size_bs = {0};
std::array<std::size_t, NumDs> size_ds = {0};
std::vector<AsGridPointer> p_as_grids;
std::vector<BsGridPointer> p_bs_grids;
std::vector<DsGridPointer> p_ds_grids;
};
template <typename Argument, typename DsDataType>
struct RotatingMemWrapperMultiD
{
@@ -318,6 +463,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
// total_time += cur_time;
// #endif
#if !defined(CK_USE_WMMA)
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
// std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
@@ -326,6 +472,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
static_cast<const void*>(gemm_args.p_a_grid),
static_cast<const void*>(gemm_args.p_b_grid));
}
#endif
}
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop));