mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
[CK_TILE] fused-moe first version (#1634)
* moe pipeline * update code * compile OK * update * update cpu reference * update pipeline_gemm0 * compiler ok * update pipeline * rename to ex pipeline * block-asm * update * update * update first gemm ok * compute correct * update file structure * update README * update * update * update code * update API * return unsupport case * add comment * update readme * update * uncomment * update * fix build err --------- Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <stdint.h>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename T>
|
||||
@@ -36,6 +37,19 @@ struct DeviceMem
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
|
||||
{
|
||||
if(mMemSize != 0)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
else
|
||||
{
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
ToDevice(t.data());
|
||||
}
|
||||
void Realloc(std::size_t mem_size)
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
@@ -92,6 +106,27 @@ struct DeviceMem
|
||||
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
}
|
||||
|
||||
// construct a host tensor with type T
|
||||
template <typename T>
|
||||
HostTensor<T> ToHost(std::size_t cpySize)
|
||||
{
|
||||
// TODO: host tensor could be slightly larger than the device tensor
|
||||
// we just copy all data from GPU buffer
|
||||
std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
|
||||
HostTensor<T> h_({host_elements});
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
return h_;
|
||||
}
|
||||
template <typename T>
|
||||
HostTensor<T> ToHost()
|
||||
{
|
||||
return ToHost<T>(mMemSize);
|
||||
}
|
||||
|
||||
void SetZero() const
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
|
||||
Reference in New Issue
Block a user