mirror of
https://github.com/nomic-ai/kompute.git
synced 2026-05-11 17:09:59 +00:00
123 lines
4.1 KiB
C++
123 lines
4.1 KiB
C++
#pragma once
|
|
|
|
#include "kompute/operations/OpAlgoLhsRhsOut.hpp"
|
|
|
|
namespace kp {
|
|
|
|
// Default constructor. Performs no work of its own beyond emitting a debug
// log line; the base class is default-initialised implicitly.
OpAlgoLhsRhsOut::OpAlgoLhsRhsOut()
{
    SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut constructor base");
}
|
|
|
|
// Constructs the operation from the Vulkan handles, the tensors to operate
// on and the workgroup configuration. Every argument is handed unchanged to
// the OpAlgoBase constructor; this class only adds a debug log line.
//
// NOTE(review): this class records the device-to-staging copy of its output
// tensor itself (see record()). No copyOutputData-style flag is passed to
// the base constructor here - confirm that the base default is the intended
// one.
OpAlgoLhsRhsOut::OpAlgoLhsRhsOut(
  std::shared_ptr<vk::PhysicalDevice> physicalDevice,
  std::shared_ptr<vk::Device> device,
  std::shared_ptr<vk::CommandBuffer> commandBuffer,
  std::vector<std::shared_ptr<Tensor>> tensors,
  const Workgroup& komputeWorkgroup)
  : OpAlgoBase(physicalDevice,
               device,
               commandBuffer,
               tensors,
               komputeWorkgroup)
{
    SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut constructor with params");
}
|
|
|
|
// Destructor. Only emits a debug log line; no explicit cleanup is performed
// in this class.
OpAlgoLhsRhsOut::~OpAlgoLhsRhsOut()
{
    SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut destructor started");
}
|
|
|
|
void
|
|
OpAlgoLhsRhsOut::init()
|
|
{
|
|
SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut init called");
|
|
|
|
if (this->mTensors.size() < 3) {
|
|
throw std::runtime_error(
|
|
"Kompute OpAlgoLhsRhsOut called with less than 1 tensor");
|
|
} else if (this->mTensors.size() > 3) {
|
|
SPDLOG_WARN(
|
|
"Kompute OpAlgoLhsRhsOut called with more than 3 this->mTensors");
|
|
}
|
|
|
|
this->mTensorLHS = this->mTensors[0];
|
|
this->mTensorRHS = this->mTensors[1];
|
|
this->mTensorOutput = this->mTensors[2];
|
|
|
|
if (!(this->mTensorLHS->isInit() && this->mTensorRHS->isInit() &&
|
|
this->mTensorOutput->isInit())) {
|
|
throw std::runtime_error(
|
|
"Kompute OpAlgoLhsRhsOut all tensor parameters must be initialised. "
|
|
"LHS: " +
|
|
std::to_string(this->mTensorLHS->isInit()) +
|
|
" RHS: " + std::to_string(this->mTensorRHS->isInit()) +
|
|
" Output: " + std::to_string(this->mTensorOutput->isInit()));
|
|
}
|
|
|
|
if (!(this->mTensorLHS->size() == this->mTensorRHS->size() &&
|
|
this->mTensorRHS->size() == this->mTensorOutput->size())) {
|
|
throw std::runtime_error(
|
|
"Kompute OpAlgoLhsRhsOut all tensor parameters must be the same size "
|
|
"LHS: " +
|
|
std::to_string(this->mTensorLHS->size()) +
|
|
" RHS: " + std::to_string(this->mTensorRHS->size()) +
|
|
" Output: " + std::to_string(this->mTensorOutput->size()));
|
|
}
|
|
|
|
SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut fetching spirv data");
|
|
|
|
std::vector<uint32_t> shaderFileData = this->fetchSpirvBinaryData();
|
|
|
|
SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut Initialising algorithm component");
|
|
|
|
this->mAlgorithm->init(shaderFileData, this->mTensors);
|
|
}
|
|
|
|
void
|
|
OpAlgoLhsRhsOut::record()
|
|
{
|
|
SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut record called");
|
|
|
|
// Barrier to ensure the data is finished writing to buffer memory
|
|
this->mTensorLHS->recordBufferMemoryBarrier(
|
|
this->mCommandBuffer,
|
|
vk::AccessFlagBits::eHostWrite,
|
|
vk::AccessFlagBits::eShaderRead,
|
|
vk::PipelineStageFlagBits::eHost,
|
|
vk::PipelineStageFlagBits::eComputeShader);
|
|
this->mTensorRHS->recordBufferMemoryBarrier(
|
|
this->mCommandBuffer,
|
|
vk::AccessFlagBits::eHostWrite,
|
|
vk::AccessFlagBits::eShaderRead,
|
|
vk::PipelineStageFlagBits::eHost,
|
|
vk::PipelineStageFlagBits::eComputeShader);
|
|
|
|
this->mAlgorithm->recordDispatch(this->mKomputeWorkgroup[0],
|
|
this->mKomputeWorkgroup[1],
|
|
this->mKomputeWorkgroup[2]);
|
|
|
|
// Barrier to ensure the shader code is executed before buffer read
|
|
this->mTensorOutput->recordBufferMemoryBarrier(
|
|
this->mCommandBuffer,
|
|
vk::AccessFlagBits::eShaderWrite,
|
|
vk::AccessFlagBits::eTransferRead,
|
|
vk::PipelineStageFlagBits::eComputeShader,
|
|
vk::PipelineStageFlagBits::eTransfer);
|
|
|
|
if (this->mTensorOutput->tensorType() == Tensor::TensorTypes::eDevice) {
|
|
this->mTensorOutput->recordCopyFromDeviceToStaging(this->mCommandBuffer,
|
|
true);
|
|
}
|
|
}
|
|
|
|
void
|
|
OpAlgoLhsRhsOut::postEval()
|
|
{
|
|
SPDLOG_DEBUG("Kompute OpAlgoLhsRhsOut postSubmit called");
|
|
|
|
this->mTensorOutput->mapDataFromHostMemory();
|
|
}
|
|
|
|
}
|