mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
CK Instance Gen (#1145)
* Format
* Format
* Format
* Remove const
* Use the right template
* Format
* Format
* add row/col instances
* Add missing file
* fixed
* fixing block to etile error
* Format
* Updates
* Format
* fixed rrr layout
* generating a sample JSON file: currently contains includes, prologue/epilogue and instances
* version where the json is passed into the instances to generate a key
* updated run function to just launch kernel
* updated run function: only contains kernel object, json file is updated but still needs to be cleaned up, added front-end API to parse JSON into character buffer
* adding in testing files
* cleaned up comments, still need to work on including header files
* removed unneeded files
* removed/commented out JSON implementation
* added fusion (prologue/epilogue) into instance generation
* working on instance selection
* added instance selection, need to fix instance validation
* removed block2etile map validity check for testing purposes
* test running: failing due to incorrect files/input
* all grid descs/ptrs completed, but device file not found
* Update test and embed modules
* Restore older version
* added convolution operation, written test, debugging generated code for compilation
* attempting to include CK in host directory: _Float16 error
* CK header file issues
* slight fix
* don't crash when hip can't report total memory
* dump generated code to a file
* changing sizes
* creating tensor descriptors using CK methods: set up grid desc manually, also trying to set up an argument pointer - this needs to be fixed
* some fixes to call the device code
* separating test files for conv and gemm
* completed arg ptr, now have linking errors
* clang format fix
* resolved linker issues in conv test
* remove dependency on libutility from ck
* resolved num dim error
* properly passing arg ptr, errors with passing typenames: redefinition/redeclaration
* undo the commenting of device function
* hand created kernel code to find rtc issues
* dump the full src to file
* resolved redeclaration errors, cleaned up errors for Amber's kernel code
* debugging purposes: redeclaration error
* config files
* resolved errors for NumTensor and redeclaration, formatted version.h
* resolved most errors in manually added kernel and my own. error with calling kernel object: overloaded function type
* WIP: close to getting kernel compiled
* WIP: fixing rtc errors
* fixed sequence errors, formatting, still one error with run fcn
* yay: kernel compiles and runs
* updated templated/generated version to run and compile
* minor fixes
* working generated example, resolved memory access error due to padding
* adding in reference kernel, validation failing against reference
* debugging: printing kernel args
* reduced error in results
* debugged reference kernel and output errors, added to generated version, currently debugging prologue function issues
* working validation (using reference convolution) with prologue function for both hard-coded and generated version
* WIP: create an alt version that creates Argument on the device
* wip: added new duplicate files, fixed fusion templating errors from working example, setting up kernel arguments
* wip: making necessary methods device code
* added grid descs, working on grid pointers, errors with stl numerics
* wip: updating kernel args - issue, replacing some std functions
* replaced std::accumulate call with temp hardcoded version
* wip: args causing memory issue
* Construct Argument object inside the kernel and use it to call convolution device function. Code runs and verification passes
* adding object file dump
* temporary hardcoding of grid size, can remove device op inst + arg ptr
* minor fix for grid size
* added modified example where arg ptr is created on the device for generated version as well
* removed device op instance and arg ptr from modified examples
* moving device op file for testing purposes and to properly build CK
* commenting out print-outs
* adjust compiler args to produce a valid ELF file
* temporary removal of validation
* reverting compiler args back for working example
* retrieve necessary arguments from generated template parameters in correct format
* calculating grid size on host-side, still need to clean up process, pass parameters to host functions properly
* scaled up factory functions/wrapper structs to implement host-side launch parameter calculations using CK host side functions - in hard-coded example
* temporary change to generate ELF format binary object file
* removed unnecessary code, added comments
* formatting fix
* cleaned up code, added new tests, restructured library: move helper into CK
* refactored launch parameter calculation to be more concise
* renamed files and variables for more clarity/uniformity
* more code cleaning, removed debug statements
* moved majority of my files into codegen directory, running properly
* updated Embed.cmake (string_view) in codegen directory
* updated host directory to match Embed.cmake as well
* added old tests in
* updated instance generation methods to be more concise
* removed layout from launch parameter calculation
* working test
* fixed issue with verification, all instances working
* updated verification in other tests
* removed duplicate matrix padder file, removed code dumps
* removed old hard-coded tests
* removed old host directory, all files in codegen directory now
* fixed copyright in files
* commenting out validation
* renamed files
* made changes for review: fixed copyright, renamed files for clarity, removed comments, refactored code
* updated headers
* removing duplicate file for fwd conv to gemm, merging with original file
* fix building codegen with clang++ directly
* resolving build error from conv_fwd_to_gemm
* fix for previous error
* renaming tests
* created common test file
* cleaned up code, added comments
* renamed device op
* fixed typos in comments
* removed extra space
* code cleanup: resolving Amber's comments
* removed wrapper struct for matrix padder, fixed template
* cleaned up if statements for better readability

---------

Co-authored-by: Paul <pfultz2@yahoo.com>
Co-authored-by: Jing Zhang <jizha@amd.com>
Co-authored-by: M. Amber Hassaan <amber_474@yahoo.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
359
include/ck/tensor_operation/gpu/device/helper.hpp
Normal file
@@ -0,0 +1,359 @@
#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include <fstream>
#include <variant>

// functions to return the corresponding structs based on generated template parameters

using layouts = std::variant<ck::tensor_layout::convolution::GNWK,
                             ck::tensor_layout::convolution::GNHWK,
                             ck::tensor_layout::convolution::NHWGK,
                             ck::tensor_layout::convolution::GNDHWK,
                             ck::tensor_layout::convolution::NDHWGK>;
// return the layout type: currently NHWGK is the only layout supported in MIOpen
auto layout_type(std::string type)
{
    if(type == "ck::tensor_layout::convolution::NHWGK")
    {
        return ck::tensor_layout::convolution::NHWGK{};
    }
    throw std::runtime_error("Incorrect layout: " + type);
}
// return the right gemm spec based on the generated template parameters
ck::tensor_operation::device::GemmSpecialization gemm_type(std::string type)
{
    if(type == "ck::tensor_operation::device::GemmSpecialization::Default")
    {
        return ck::tensor_operation::device::GemmSpecialization::Default;
    }
    if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding")
    {
        return ck::tensor_operation::device::GemmSpecialization::MNKPadding;
    }
    throw std::runtime_error("Incorrect gemm spec: " + type);
}

// return the type of convolution
ck::tensor_operation::device::ConvolutionForwardSpecialization conv_type(std::string type)
{
    if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default")
    {
        return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
    }
    if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0")
    {
        return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
    }
    if(type ==
       "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0")
    {
        return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
    }
    if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC")
    {
        return ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
    }
    throw std::runtime_error("Incorrect conv spec: " + type);
}

// Function to call MatrixPadder via a wrapper struct
// NOTE: CK only uses MNKPadding for forward convolution
template <typename CDesc_MRaw_NRaw>
auto pad(ck::index_t mpb,
         ck::index_t npb,
         ck::index_t kpb,
         ck::tensor_operation::device::GemmSpecialization gemm,
         CDesc_MRaw_NRaw conv)
{
    if(gemm == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
    {
        ck::tensor_operation::device::MatrixPadder<
            ck::tensor_operation::device::GemmSpecialization::MNKPadding,
            ck::index_t,
            ck::index_t,
            ck::index_t>
            a;
        a.MPerTile_ = mpb;
        a.NPerTile_ = npb;
        a.KPerTile_ = kpb;
        return grid_desc(a, conv);
    }
    throw std::runtime_error("Incorrect template parameters, check gemm spec");
}

// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num
// dims
// FIXME: add a way to properly pass in the layout
auto transform_conv(ck::index_t num_dim,
                    ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
                    ck::Array<ck::index_t, 5> out_lengths,
                    ck::Array<ck::index_t, 5> out_strides)
{
    if(num_dim == 2 &&
       spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            2,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    if(num_dim == 2 &&
       spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            2,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    if(num_dim == 2 &&
       spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            2,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            2,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    throw std::runtime_error("Incorrect conv spec");
}

auto transform_conv_3d(ck::index_t num_dim,
                       ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
                       ck::Array<ck::index_t, 6> out_lengths,
                       ck::Array<ck::index_t, 6> out_strides)
{
    if(num_dim == 3 &&
       spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            3,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    if(num_dim == 3 &&
       spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            3,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    if(num_dim == 3 &&
       spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            3,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            3,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    throw std::runtime_error("Incorrect conv spec");
}

auto transform_conv_1d(ck::index_t num_dim,
                       ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
                       ck::Array<ck::index_t, 4> out_lengths,
                       ck::Array<ck::index_t, 4> out_strides)
{
    if(num_dim == 1 &&
       spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            1,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    if(num_dim == 1 &&
       spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            1,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    if(num_dim == 1 &&
       spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            1,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
    {
        ck::tensor_operation::TransformConvFwdToGemm<
            1,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
            conv_fwd;

        auto res = ck::tensor_operation::TransformConv();
        return res.transform_func(out_lengths, out_strides, conv_fwd);
    }
    throw std::runtime_error("Incorrect dims or conv spec");
}

template <typename CGridDesc_M_N>
auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder)
{
    if(m_per_block == 32 && n_per_block == 64)
    {
        ck::BlockToCTileMap_M00_N0_M01Adapt<32, 64, CGridDesc_M_N> b2e(matrix_padder);
        return b2e.CalculateGridSize(matrix_padder);
    }
    if(m_per_block == 32 && n_per_block == 128)
    {
        ck::BlockToCTileMap_M00_N0_M01Adapt<32, 128, CGridDesc_M_N> b2e(matrix_padder);
        return b2e.CalculateGridSize(matrix_padder);
    }
    if(m_per_block == 64 && n_per_block == 32)
    {
        ck::BlockToCTileMap_M00_N0_M01Adapt<64, 32, CGridDesc_M_N> b2e(matrix_padder);
        return b2e.CalculateGridSize(matrix_padder);
    }
    if(m_per_block == 64 && n_per_block == 64)
    {
        ck::BlockToCTileMap_M00_N0_M01Adapt<64, 64, CGridDesc_M_N> b2e(matrix_padder);
        return b2e.CalculateGridSize(matrix_padder);
    }
    if(m_per_block == 64 && n_per_block == 128)
    {
        ck::BlockToCTileMap_M00_N0_M01Adapt<64, 128, CGridDesc_M_N> b2e(matrix_padder);
        return b2e.CalculateGridSize(matrix_padder);
    }
    if(m_per_block == 128 && n_per_block == 32)
    {
        ck::BlockToCTileMap_M00_N0_M01Adapt<128, 32, CGridDesc_M_N> b2e(matrix_padder);
        return b2e.CalculateGridSize(matrix_padder);
    }
    if(m_per_block == 128 && n_per_block == 64)
    {
        ck::BlockToCTileMap_M00_N0_M01Adapt<128, 64, CGridDesc_M_N> b2e(matrix_padder);
        return b2e.CalculateGridSize(matrix_padder);
    }
    if(m_per_block == 128 && n_per_block == 128)
    {
        ck::BlockToCTileMap_M00_N0_M01Adapt<128, 128, CGridDesc_M_N> b2e(matrix_padder);
        return b2e.CalculateGridSize(matrix_padder);
    }
    if(m_per_block == 128 && n_per_block == 256)
    {
        ck::BlockToCTileMap_M00_N0_M01Adapt<128, 256, CGridDesc_M_N> b2e(matrix_padder);
        return b2e.CalculateGridSize(matrix_padder);
    }
    if(m_per_block == 256 && n_per_block == 128)
    {
        ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, CGridDesc_M_N> b2e(matrix_padder);
        return b2e.CalculateGridSize(matrix_padder);
    }
    throw std::runtime_error("Incorrect template parameters");
}

// wrapper functions by dims to get grid size - uses above 3 functions
// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions
auto get_launch_params_1d(ck::host::Solution solution,
                          ck::Array<ck::index_t, 4> out_lengths,
                          ck::Array<ck::index_t, 4> out_strides)
{
    auto num_dim     = solution.GetTemplateParameter<ck::index_t>("NumDim");
    auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
    auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
    auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
    auto GemmType    = solution.GetTemplateParameter<std::string>("GemmSpecialization");
    auto ConvType    = solution.GetTemplateParameter<std::string>("ConvSpecialization");
    ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
    ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
    auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides);
    auto matrix_padder =
        pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
    auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
    return b2e;
}

auto get_launch_params(ck::host::Solution solution,
                       ck::Array<ck::index_t, 5> out_lengths,
                       ck::Array<ck::index_t, 5> out_strides)
{
    auto num_dim     = solution.GetTemplateParameter<ck::index_t>("NumDim");
    auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
    auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
    auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
    auto GemmType    = solution.GetTemplateParameter<std::string>("GemmSpecialization");
    auto ConvType    = solution.GetTemplateParameter<std::string>("ConvSpecialization");
    ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
    ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
    auto conv_to_gemm_transformer = transform_conv(num_dim, ConvSpec, out_lengths, out_strides);
    auto matrix_padder =
        pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
    auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
    return b2e;
}

auto get_launch_params_3d(ck::host::Solution solution,
                          ck::Array<ck::index_t, 6> out_lengths,
                          ck::Array<ck::index_t, 6> out_strides)
{
    auto num_dim     = solution.GetTemplateParameter<ck::index_t>("NumDim");
    auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
    auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
    auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
    auto GemmType    = solution.GetTemplateParameter<std::string>("GemmSpecialization");
    auto ConvType    = solution.GetTemplateParameter<std::string>("ConvSpecialization");
    ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
    ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
    auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides);
    auto matrix_padder =
        pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
    auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
    return b2e;
}