mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Add rotating buff for gemm_multi_d (#1411)
* add rotating_buff for gemm_multi_d * format * Update flush_cache.hpp * Update gtest.cmake --------- Co-authored-by: Jing Zhang <jizhan@fb.com> Co-authored-by: Haocong WANG <haocwang@amd.com>
This commit is contained in:
@@ -14,6 +14,124 @@
|
||||
namespace ck {
|
||||
namespace utility {
|
||||
|
||||
template <typename Argument, typename DsDataType>
|
||||
struct RotatingMemWrapperMultiD
|
||||
{
|
||||
static constexpr index_t NumDs = DsDataType::Size();
|
||||
|
||||
using ADataType = decltype(Argument::p_a_grid);
|
||||
using BDataType = decltype(Argument::p_b_grid);
|
||||
using DsGridPointer = decltype(Argument::p_ds_grid);
|
||||
|
||||
RotatingMemWrapperMultiD() = delete;
|
||||
RotatingMemWrapperMultiD(Argument& arg_,
|
||||
std::size_t rotating_count_,
|
||||
std::size_t size_a_,
|
||||
std::size_t size_b_,
|
||||
std::array<std::size_t, NumDs> size_ds_)
|
||||
: arg(arg_),
|
||||
rotating_count(rotating_count_),
|
||||
size_a(size_a_),
|
||||
size_b(size_b_),
|
||||
size_ds(size_ds_)
|
||||
{
|
||||
p_a_grids.push_back(arg.p_a_grid);
|
||||
p_b_grids.push_back(arg.p_b_grid);
|
||||
p_ds_grids.push_back(arg.p_ds_grid);
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
{
|
||||
void* pADeviceBuf;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
|
||||
const_cast<void*>(p_a_grids[0]),
|
||||
size_a_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_a_grids.push_back(pADeviceBuf);
|
||||
}
|
||||
|
||||
{
|
||||
void* pBDeviceBuf;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
|
||||
const_cast<void*>(p_b_grids[0]),
|
||||
size_b_,
|
||||
hipMemcpyDeviceToDevice));
|
||||
p_b_grids.push_back(pBDeviceBuf);
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
DsGridPointer ds_buffer;
|
||||
static_for<0, NumDs, 1>{}([&](auto j) {
|
||||
void* pDDeviceBuf;
|
||||
hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
|
||||
hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
|
||||
static_cast<const void*>(p_ds_grids[0][j]),
|
||||
size_ds_[j],
|
||||
hipMemcpyDeviceToDevice));
|
||||
|
||||
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
|
||||
|
||||
ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
|
||||
});
|
||||
|
||||
p_ds_grids.push_back(ds_buffer);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Next()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
std::size_t idx = iter++ % rotating_count;
|
||||
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
|
||||
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
|
||||
arg.p_ds_grid = p_ds_grids[idx];
|
||||
}
|
||||
}
|
||||
void Print()
|
||||
{
|
||||
std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
|
||||
<< ", rotating_count: " << rotating_count << "}" << std::endl;
|
||||
}
|
||||
~RotatingMemWrapperMultiD()
|
||||
{
|
||||
if(rotating_count > 1)
|
||||
{
|
||||
// restore ptr
|
||||
arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
|
||||
arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
|
||||
arg.p_ds_grid = p_ds_grids[0];
|
||||
|
||||
// free device mem
|
||||
for(size_t i = 1; i < rotating_count; i++)
|
||||
{
|
||||
hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
|
||||
hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
|
||||
|
||||
static_for<0, NumDs, 1>{}([&](auto j) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
|
||||
hip_check_error(
|
||||
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
Argument& arg;
|
||||
std::size_t iter = 0;
|
||||
std::size_t rotating_count = 1;
|
||||
std::size_t size_a = 0;
|
||||
std::size_t size_b = 0;
|
||||
std::array<std::size_t, NumDs> size_ds = {0};
|
||||
std::vector<const void*> p_a_grids;
|
||||
std::vector<const void*> p_b_grids;
|
||||
std::vector<DsGridPointer> p_ds_grids;
|
||||
};
|
||||
|
||||
template <typename Argument>
|
||||
struct RotatingMemWrapper
|
||||
{
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -163,14 +164,65 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
std::array<std::size_t, NumDTensor> DsSize;
|
||||
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_,
|
||||
stream_config.rotating_count,
|
||||
arg_.M * arg_.K * sizeof(ADataType),
|
||||
arg_.K * arg_.N * sizeof(BDataType),
|
||||
DsSize);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
|
||||
{
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
|
||||
Reference in New Issue
Block a user