mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
[CK_TILE] fused-moe first version (#1634)
* moe pipeline
* update code
* compile OK
* update
* update cpu reference
* update pipeline_gemm0
* compiler ok
* update pipeline
* rename to ex pipeline
* block-asm
* update
* update
* update first gemm ok
* compute correct
* update file structure
* update README
* update
* update
* update code
* update API
* return unsupport case
* add comment
* update readme
* update
* uncomment
* update
* fix build err
---------
Co-authored-by: valarLip <340077269@qq.com>
[ROCm/composable_kernel commit: 440e28b08f]
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <stdint.h>
|
||||
#include <stdexcept>
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename T>
|
||||
@@ -36,6 +37,19 @@ struct DeviceMem
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
|
||||
{
|
||||
if(mMemSize != 0)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
|
||||
}
|
||||
else
|
||||
{
|
||||
mpDeviceBuf = nullptr;
|
||||
}
|
||||
ToDevice(t.data());
|
||||
}
|
||||
void Realloc(std::size_t mem_size)
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
@@ -92,6 +106,27 @@ struct DeviceMem
|
||||
HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
}
|
||||
|
||||
// construct a host tensor with type T
|
||||
template <typename T>
|
||||
HostTensor<T> ToHost(std::size_t cpySize)
|
||||
{
|
||||
// TODO: host tensor could be slightly larger than the device tensor
|
||||
// we just copy all data from GPU buffer
|
||||
std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
|
||||
HostTensor<T> h_({host_elements});
|
||||
if(mpDeviceBuf)
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
|
||||
}
|
||||
return h_;
|
||||
}
|
||||
template <typename T>
|
||||
HostTensor<T> ToHost()
|
||||
{
|
||||
return ToHost<T>(mMemSize);
|
||||
}
|
||||
|
||||
void SetZero() const
|
||||
{
|
||||
if(mpDeviceBuf)
|
||||
|
||||
Reference in New Issue
Block a user