mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Add rotating buff for gemm_multi_d (#1411)
* add rotating_buff for gemm_multi_d * format * Update flush_cache.hpp * Update gtest.cmake --------- Co-authored-by: Jing Zhang <jizhan@fb.com> Co-authored-by: Haocong WANG <haocwang@amd.com>
This commit is contained in:
@@ -14,6 +14,124 @@
|
||||
namespace ck {
|
||||
namespace utility {
|
||||
|
||||
template <typename Argument, typename DsDataType>
|
||||
struct RotatingMemWrapperMultiD
|
||||
{
|
||||
static constexpr index_t NumDs = DsDataType::Size();
|
||||
|
||||
using ADataType = decltype(Argument::p_a_grid);
|
||||
using BDataType = decltype(Argument::p_b_grid);
|
||||
using DsGridPointer = decltype(Argument::p_ds_grid);
|
||||
|
||||
RotatingMemWrapperMultiD() = delete;
|
||||
RotatingMemWrapperMultiD(Argument& arg_,
|
||||
std::size_t rotating_count_,
|
||||
std::size_t size_a_,
|
||||
std::size_t size_b_,
|
||||
std::array<std::size_t, NumDs> size_ds_)
|
||||
: arg(arg_),
|
||||
rotating_count(rotating_count_),
|
||||
size_a(size_a_),
|
||||
size_b(size_b_),
|
||||
size_ds(size_ds_)
|
||||
{
|
||||
p_a_grids.push_back(arg.p_a_grid);
|
||||
p_b_grids.push_back(arg.p_b_grid);
|
||||
p_ds_grids.push_back(arg.p_ds_grid);
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
void* pADeviceBuf;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
const_cast<void*>(p_a_grids[0]),
|
||||
size_a_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_a_grids.push_back(pADeviceBuf);
|
||||
}
|
||||
|
||||
{
|
||||
void* pBDeviceBuf;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
const_cast<void*>(p_b_grids[0]),
|
||||
size_b_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_b_grids.push_back(pBDeviceBuf);
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
DsGridPointer ds_buffer;
|
||||
static_for<0, NumDs, 1>{}([&](auto j) {
|
||||
void* pDDeviceBuf;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
|
||||
static_cast<const void*>(p_ds_grids[0][j]),
|
||||
size_ds_[j],
|
||||
hipMemcpyDeviceToDevice));
|
||||
|
||||
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
|
||||
|
||||
ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
|
||||
});
|
||||
|
||||
p_ds_grids.push_back(ds_buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Next()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
std::size_t idx = iter++ % rotating_count;
|
||||
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
|
||||
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
|
||||
arg.p_ds_grid = p_ds_grids[idx];
|
||||
}
|
||||
}
|
||||
void Print()
|
||||
{
|
||||
std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
|
||||
<< ", rotating_count: " << rotating_count << "}" << std::endl;
|
||||
}
|
||||
~RotatingMemWrapperMultiD()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
// restore ptr
|
||||
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
|
||||
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
|
||||
arg.p_ds_grid = p_ds_grids[0];
|
||||
|
||||
// free device mem
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
|
||||
|
||||
static_for<0, NumDs, 1>{}([&](auto j) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
|
||||
hip_check_error(
|
||||
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Argument& arg;
|
||||
std::size_t iter = 0;
|
||||
std::size_t rotating_count = 1;
|
||||
std::size_t size_a = 0;
|
||||
std::size_t size_b = 0;
|
||||
std::array<std::size_t, NumDs> size_ds = {0};
|
||||
std::vector<const void*> p_a_grids;
|
||||
std::vector<const void*> p_b_grids;
|
||||
std::vector<DsGridPointer> p_ds_grids;
|
||||
};
|
||||
|
||||
template <typename Argument>
|
||||
struct RotatingMemWrapper
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user