mirror of
https://github.com/nomic-ai/kompute.git
synced 2026-05-12 01:19:58 +00:00
Added capabilities for specialization data based on the size of the vectors passed
This commit is contained in:
@@ -38,7 +38,13 @@ Algorithm::init(const std::vector<char>& shaderFileData,
|
||||
// TODO: Move to util function
|
||||
this->createParameters(tensorParams);
|
||||
this->createShaderModule(shaderFileData);
|
||||
this->createPipeline();
|
||||
|
||||
std::vector<uint32_t> sizes;
|
||||
for (std::shared_ptr<Tensor> tensor: tensorParams) {
|
||||
SPDLOG_WARN("size: {}", tensor->size());
|
||||
sizes.push_back(tensor->size());
|
||||
}
|
||||
this->createPipeline(sizes);
|
||||
}
|
||||
|
||||
void
|
||||
@@ -152,7 +158,7 @@ Algorithm::createShaderModule(const std::vector<char>& shaderFileData)
|
||||
}
|
||||
|
||||
void
|
||||
Algorithm::createPipeline()
|
||||
Algorithm::createPipeline(std::vector<uint32_t> specializationData)
|
||||
{
|
||||
SPDLOG_DEBUG("Kompute Algorithm calling create Pipeline");
|
||||
|
||||
@@ -166,12 +172,29 @@ Algorithm::createPipeline()
|
||||
this->mDevice->createPipelineLayout(
|
||||
&pipelineLayoutInfo, nullptr, this->mPipelineLayout.get());
|
||||
|
||||
std::vector<vk::SpecializationMapEntry> specializationEntries;
|
||||
|
||||
for (size_t i = 0; i < specializationData.size(); i++) {
|
||||
vk::SpecializationMapEntry specializationEntry(
|
||||
static_cast<uint32_t>(i),
|
||||
static_cast<uint32_t>(sizeof(uint32_t) * i),
|
||||
sizeof(uint32_t));
|
||||
|
||||
specializationEntries.push_back(specializationEntry);
|
||||
}
|
||||
|
||||
vk::SpecializationInfo specializationInfo(
|
||||
static_cast<uint32_t>(specializationEntries.size()),
|
||||
specializationEntries.data(),
|
||||
sizeof(uint32_t) * specializationEntries.size(),
|
||||
specializationData.data());
|
||||
|
||||
vk::PipelineShaderStageCreateInfo shaderStage(
|
||||
vk::PipelineShaderStageCreateFlags(),
|
||||
vk::ShaderStageFlagBits::eCompute,
|
||||
*this->mShaderModule,
|
||||
"main",
|
||||
nullptr);
|
||||
&specializationInfo);
|
||||
|
||||
vk::ComputePipelineCreateInfo pipelineInfo(vk::PipelineCreateFlags(),
|
||||
shaderStage,
|
||||
|
||||
Reference in New Issue
Block a user