mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Wmma support for multiple ABD GEMM (#2803)
* multi_abd wmma support: - Add multiple A and B support to multiple D implementation (gridwise level) - Add multi_abd GEMM (device level) - Add instances (xdl parity) - Add tests (both xdl and wmma) - Add examples - Add ckProfiler support (both xdl and wmma) * Fix bug in device print function * Fix unused template parameter * Fix batched gemm for multiABD gridwise implementation * Fix gemm_universal_reduce with multiABDs gridwise implementation --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -15,6 +15,151 @@
|
||||
namespace ck {
|
||||
namespace utility {
|
||||
|
||||
template <typename Argument, typename AsDataType, typename BsDataType, typename DsDataType>
|
||||
struct RotatingMemWrapperMultiABD
|
||||
{
|
||||
static constexpr index_t NumAs = AsDataType::Size();
|
||||
static constexpr index_t NumBs = BsDataType::Size();
|
||||
static constexpr index_t NumDs = DsDataType::Size();
|
||||
|
||||
using AsGridPointer = decltype(Argument::p_as_grid);
|
||||
using BsGridPointer = decltype(Argument::p_bs_grid);
|
||||
using DsGridPointer = decltype(Argument::p_ds_grid);
|
||||
|
||||
RotatingMemWrapperMultiABD() = delete;
|
||||
RotatingMemWrapperMultiABD(Argument& arg_,
|
||||
std::size_t rotating_count_,
|
||||
std::array<std::size_t, NumAs> size_as_,
|
||||
std::array<std::size_t, NumBs> size_bs_,
|
||||
std::array<std::size_t, NumDs> size_ds_)
|
||||
: arg(arg_),
|
||||
rotating_count(rotating_count_),
|
||||
size_as(size_as_),
|
||||
size_bs(size_bs_),
|
||||
size_ds(size_ds_)
|
||||
{
|
||||
p_as_grids.push_back(arg.p_as_grid);
|
||||
p_bs_grids.push_back(arg.p_bs_grid);
|
||||
p_ds_grids.push_back(arg.p_ds_grid);
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
AsGridPointer as_buffer;
|
||||
static_for<0, NumAs, 1>{}([&](auto j) {
|
||||
void* pADeviceBuf;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_as_[j]));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
static_cast<const void*>(p_as_grids[0][j]),
|
||||
size_as_[j],
|
||||
hipMemcpyDeviceToDevice));
|
||||
using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
|
||||
|
||||
as_buffer(j) = static_cast<const ADataType*>(pADeviceBuf);
|
||||
});
|
||||
p_as_grids.push_back(as_buffer);
|
||||
}
|
||||
|
||||
{
|
||||
BsGridPointer bs_buffer;
|
||||
static_for<0, NumBs, 1>{}([&](auto j) {
|
||||
void* pBDeviceBuf;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_bs_[j]));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
static_cast<const void*>(p_bs_grids[0][j]),
|
||||
size_bs_[j],
|
||||
hipMemcpyDeviceToDevice));
|
||||
using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
|
||||
|
||||
bs_buffer(j) = static_cast<const BDataType*>(pBDeviceBuf);
|
||||
});
|
||||
p_bs_grids.push_back(bs_buffer);
|
||||
}
|
||||
|
||||
{
|
||||
DsGridPointer ds_buffer;
|
||||
static_for<0, NumDs, 1>{}([&](auto j) {
|
||||
void* pDDeviceBuf;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
|
||||
static_cast<const void*>(p_ds_grids[0][j]),
|
||||
size_ds_[j],
|
||||
hipMemcpyDeviceToDevice));
|
||||
|
||||
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
|
||||
|
||||
ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
|
||||
});
|
||||
|
||||
p_ds_grids.push_back(ds_buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Next()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
std::size_t idx = iter++ % rotating_count;
|
||||
arg.p_as_grid = p_as_grids[idx];
|
||||
arg.p_bs_grid = p_bs_grids[idx];
|
||||
arg.p_ds_grid = p_ds_grids[idx];
|
||||
}
|
||||
}
|
||||
void Print()
|
||||
{
|
||||
std::cout << "RotatingMemWrapperMultiD: { size_a: {";
|
||||
static_for<0, NumAs, 1>{}(
|
||||
[&](auto j) { std::cout << size_as[j] << (j.value < NumAs - 1 ? ", " : ""); });
|
||||
std::cout << "}, size_b: {";
|
||||
static_for<0, NumBs, 1>{}(
|
||||
[&](auto j) { std::cout << size_bs[j] << (j.value < NumBs - 1 ? ", " : ""); });
|
||||
std::cout << "}, rotating_count: " << rotating_count << "}" << std::endl;
|
||||
}
|
||||
~RotatingMemWrapperMultiABD()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
// restore ptr
|
||||
arg.p_as_grid = p_as_grids[0];
|
||||
arg.p_bs_grid = p_bs_grids[0];
|
||||
arg.p_ds_grid = p_ds_grids[0];
|
||||
|
||||
// free device mem
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
static_for<0, NumAs, 1>{}([&](auto j) {
|
||||
using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
|
||||
hip_check_error(
|
||||
hipFree(static_cast<void*>(const_cast<ADataType*>(p_as_grids[i][j]))));
|
||||
});
|
||||
|
||||
static_for<0, NumBs, 1>{}([&](auto j) {
|
||||
using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
|
||||
hip_check_error(
|
||||
hipFree(static_cast<void*>(const_cast<BDataType*>(p_bs_grids[i][j]))));
|
||||
});
|
||||
|
||||
static_for<0, NumDs, 1>{}([&](auto j) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
|
||||
hip_check_error(
|
||||
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Argument& arg;
|
||||
std::size_t iter = 0;
|
||||
std::size_t rotating_count = 1;
|
||||
std::array<std::size_t, NumAs> size_as = {0};
|
||||
std::array<std::size_t, NumBs> size_bs = {0};
|
||||
std::array<std::size_t, NumDs> size_ds = {0};
|
||||
std::vector<AsGridPointer> p_as_grids;
|
||||
std::vector<BsGridPointer> p_bs_grids;
|
||||
std::vector<DsGridPointer> p_ds_grids;
|
||||
};
|
||||
|
||||
template <typename Argument, typename DsDataType>
|
||||
struct RotatingMemWrapperMultiD
|
||||
{
|
||||
@@ -318,6 +463,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
// total_time += cur_time;
|
||||
// #endif
|
||||
|
||||
#if !defined(CK_USE_WMMA)
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
// std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
|
||||
@@ -326,6 +472,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
static_cast<const void*>(gemm_args.p_a_grid),
|
||||
static_cast<const void*>(gemm_args.p_b_grid));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
|
||||
hip_check_error(hipEventSynchronize(stop));
|
||||
|
||||
Reference in New Issue
Block a user