Added capabilities for specialization data based on the size of the vectors passed

This commit is contained in:
Alejandro Saucedo
2020-08-29 17:03:24 +01:00
parent b80548ac3d
commit 23cf43e231
5 changed files with 140 additions and 109 deletions

View File

@@ -38,7 +38,13 @@ Algorithm::init(const std::vector<char>& shaderFileData,
// TODO: Move to util function
this->createParameters(tensorParams);
this->createShaderModule(shaderFileData);
this->createPipeline();
std::vector<uint32_t> sizes;
for (std::shared_ptr<Tensor> tensor: tensorParams) {
SPDLOG_WARN("size: {}", tensor->size());
sizes.push_back(tensor->size());
}
this->createPipeline(sizes);
}
void
@@ -152,7 +158,7 @@ Algorithm::createShaderModule(const std::vector<char>& shaderFileData)
}
void
Algorithm::createPipeline()
Algorithm::createPipeline(std::vector<uint32_t> specializationData)
{
SPDLOG_DEBUG("Kompute Algorithm calling create Pipeline");
@@ -166,12 +172,29 @@ Algorithm::createPipeline()
this->mDevice->createPipelineLayout(
&pipelineLayoutInfo, nullptr, this->mPipelineLayout.get());
std::vector<vk::SpecializationMapEntry> specializationEntries;
for (size_t i = 0; i < specializationData.size(); i++) {
vk::SpecializationMapEntry specializationEntry(
static_cast<uint32_t>(i),
static_cast<uint32_t>(sizeof(uint32_t) * i),
sizeof(uint32_t));
specializationEntries.push_back(specializationEntry);
}
vk::SpecializationInfo specializationInfo(
static_cast<uint32_t>(specializationEntries.size()),
specializationEntries.data(),
sizeof(uint32_t) * specializationEntries.size(),
specializationData.data());
vk::PipelineShaderStageCreateInfo shaderStage(
vk::PipelineShaderStageCreateFlags(),
vk::ShaderStageFlagBits::eCompute,
*this->mShaderModule,
"main",
nullptr);
&specializationInfo);
vk::ComputePipelineCreateInfo pipelineInfo(vk::PipelineCreateFlags(),
shaderStage,