From 3f190b0779ceedf7aaf0b380712fda0518de72c1 Mon Sep 17 00:00:00 2001 From: "Po-Yen, Chen" Date: Sun, 21 Aug 2022 01:16:47 -0400 Subject: [PATCH] Add new DeviceMem API to copy memory --- .../ck/library/utility/device_memory.hpp | 41 +++++++++++++++++-- library/src/utility/device_memory.cpp | 29 +++++++++++-- 2 files changed, 64 insertions(+), 6 deletions(-) diff --git a/library/include/ck/library/utility/device_memory.hpp b/library/include/ck/library/utility/device_memory.hpp index 3c4ece4406..90a7ab7be4 100644 --- a/library/include/ck/library/utility/device_memory.hpp +++ b/library/include/ck/library/utility/device_memory.hpp @@ -3,8 +3,12 @@ #pragma once +#include + #include +#include "ck/utility/data_type.hpp" + template __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) { @@ -14,19 +18,50 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) } } -struct DeviceMem +class DeviceMem { + void ToDeviceImpl(const void* p) const; + void FromDeviceImpl(void* p) const; + + public: DeviceMem() = delete; DeviceMem(std::size_t mem_size); void* GetDeviceBuffer() const; std::size_t GetBufferSize() const; - void ToDevice(const void* p) const; - void FromDevice(void* p) const; void SetZero() const; template void SetValue(T x) const; ~DeviceMem(); + template + std::enable_if_t && + (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v)> + FromDevice(T* host) const + { + assert(GetBufferSize() % sizeof(T) == 0); + + FromDeviceImpl(host); + } + + template + std::enable_if_t || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v> + ToDevice(const T* host) const + { + assert(GetBufferSize() % sizeof(T) == 0); + + ToDeviceImpl(host); + } + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + void ToDevice(const ck::int4_t* host) const; + void FromDevice(ck::int4_t* host) const; +#endif + void* mpDeviceBuf; 
std::size_t mMemSize; }; diff --git a/library/src/utility/device_memory.cpp b/library/src/utility/device_memory.cpp index 90f943313b..96a3cb510e 100644 --- a/library/src/utility/device_memory.cpp +++ b/library/src/utility/device_memory.cpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/host_utility/hip_check_error.hpp" +#include +#include "ck/host_utility/hip_check_error.hpp" #include "ck/library/utility/device_memory.hpp" DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) @@ -14,16 +15,38 @@ void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; } std::size_t DeviceMem::GetBufferSize() const { return mMemSize; } -void DeviceMem::ToDevice(const void* p) const +void DeviceMem::ToDeviceImpl(const void* p) const { hip_check_error(hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); } -void DeviceMem::FromDevice(void* p) const +void DeviceMem::FromDeviceImpl(void* p) const { hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); } +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +void DeviceMem::ToDevice(const ck::int4_t* host) const +{ + const std::size_t count = GetBufferSize(); + const auto buffer = std::make_unique(count); + + std::copy_n(host, count, buffer.get()); + + ToDevice(buffer.get()); +} + +void DeviceMem::FromDevice(ck::int4_t* host) const +{ + const std::size_t count = GetBufferSize(); + const auto buffer = std::make_unique(count); + + FromDevice(buffer.get()); + + std::copy_n(buffer.get(), count, host); +} +#endif + void DeviceMem::SetZero() const { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); } DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }