mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
CK Instance Gen (#1145)
* Format * Format * Format * Remove const * Use the right template * Format * Format * add row/col instances * Add missing file * fixed * fixing block to etile error * Format * Updates * Format * fixed rrr layout * generating a sample JSON file: currently contains includes, prologue/epilogue and instances * version where the json is passed into the instances to generate a key * updated run function to just launch kernel * updated run function: only contains kernel object, json file is updated but still needs to be cleaned up, added front-end API to parse JSON into character buffer * adding in testing files * cleaned up comments, still need to work on including header files * removed unneeded files * removed/commented out JSON implementation * added fusion(prologue/epilogue) into instance generation * working on instance selection * added instance selection, need to fix instance validation * removed block2etile map validity check for testing purposes * test running: failing due to incorrect files/input * all grid descs/ptrs completed, but device file not found * Update test and embed modules * Restore older version * added convolution operation, written test, debugging generated code for compilation * attempting to include CK in host directory: _Float16 error * CK header file issues * slight fix * don't crash when hip can't report total memory * dump generated code to a file * changing sizes * creating tensor descriptors using CK methods: set up grid desc manually, also trying to set up an argument pointer - this needs to be fixed * some fixes to call the device code * separating test files for conv and gemm * completed arg ptr, now have linking errors * clang format fix * resolved linker issues in conv test * remove dependency on libutility from ck * resolved num dim error * properly passing arg ptr, errors with passing typenames: redefinition/redeclaration * undo the commenting of device function * hand created kernel code to find rtc issues * dump the full src to 
file * resolved redeclaration errors, cleaned up errors for Amber's kernel code * debugging purposes: redeclaration error * config files * resolved errors for NumTensor and redeclaration, formatted version.h * resolved most errors in manually added kernel and my own. error with calling kernel object: overloaded function type * WIP: close to getting kernel compiled * WIP: fixing rtc errors * fixed sequence errors, formatting, still one error with run fcn * yay: kernel compiles and runs * updated templated/generated version to run and compile * minor fixes * working generated example, resolved memory access error due to padding * adding in reference kernel, validation failing against reference * debugging: printing kernel argsz * reduced error in results * debugged reference kernel and output errors, added to generated version, currently debugging prologue function issues * working validation (using reference convolution) with prologue function for both hard-coded and generated version * WIP: create an alt version that creates Argument on the device * wip: added new duplicate files, fixed fusion templating errors from working example, setting up kernel arguments * wip: making necessary methods device code * added grid descs, working on grid pointers, errors with stl numerics * wip: updating kernel args - issue, replacing some std functions * replaced std::accumulate call with temp hardcoded version * wip: args causing memory issue * Construct Argument object inside the kernel and use it to call convolution device function. 
Code runs and verification passes * adding object file dump * temporary hardcoding of grid size, can remove device op inst + arg ptr * minor fix for grid size * added modified example where arg ptr is created on the device for generated version as well * removed device op instance and arg ptr from modified examples * moving device op file for testing purposes and to properly build CK * commenting out print-outs * adjust compiler args to produce a valid ELF file * temporary removal of validation * reverting compiler args back for working example * retrieve necessary arguments from generated template parameters in correct format * calculating grid size on host-side, still need to clean up process, pass parameters to host functions properly * scaled up factory functions/wrapper structs to implement host-side launch parameter calculations using CK host side functions - in hard-coded example * temporary change to generate ELF format binary object file * removed unecessary code, added comments * formatting fix * cleaned up code, added new tests, restructured library: move helper into CK * refactored launch parameter calculation to be more concise * renamed files and variables for more clarity/uniformity * more code cleaning, removed debug statements * moved majority of my files into codegen directory, running properly * updated Embed.cmake(string_view) in codegen directory * updated host directory to match Embed.cmake as well * added old tests in * updated instance generation methods to be more concise * removed layout from launch parameter calculation * working test * fixed issue with verification, all instances working * updated verification in other tests * removed duplicate matrix padder file, removed code dumps * removed old hard-coded tests * removed old host directory, all files in codegen directory now * fixed copyright in files * commenting out validation * renamed files * made changes for review: fixed copyright, renamed files for clarity, removed comments, 
refactored code * updated headers * removing duplicate file for fwd conv to gemm, merging with original file * fix building codegen with clang++ directly * resolving build error from conv_fwd_to_gemm * fix for previous error * renaming tests * created common test file * cleaned up code, added comments * renamed device op * fixed typos in comments * removed extra space * code cleanup: resolving Amber's comments * removed wrapper struct for matrix padder, fixed template * cleaned up if statements for better readability --------- Co-authored-by: Paul <pfultz2@yahoo.com> Co-authored-by: Jing Zhang <jizha@amd.com> Co-authored-by: M. Amber Hassaan <amber_474@yahoo.com> Co-authored-by: illsilin <Illia.Silin@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
359
include/ck/tensor_operation/gpu/device/helper.hpp
Normal file
359
include/ck/tensor_operation/gpu/device/helper.hpp
Normal file
@@ -0,0 +1,359 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include <fstream>
|
||||
#include <variant>
|
||||
|
||||
// functions to return the corresponding structs based on generated template parameters
|
||||
|
||||
using layouts = std::variant<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::NHWGK,
|
||||
ck::tensor_layout::convolution::GNDHWK,
|
||||
ck::tensor_layout::convolution::NDHWGK>;
|
||||
// return the layout type: currently this is the only type supported in MIOpen
|
||||
auto layout_type(std::string type)
|
||||
{
|
||||
if(type == "ck::tensor_layout::convolution::NHWGK")
|
||||
{
|
||||
return ck::tensor_layout::convolution::NHWGK{};
|
||||
}
|
||||
throw std::runtime_error("Incorrect layout");
|
||||
}
|
||||
// return the right gemm spec based on the generated template parameters
|
||||
// Map a GemmSpecialization name (as it appears in generated template parameters)
// to the corresponding enum value.
// \throws std::runtime_error for any unsupported specialization name.
// 'inline' avoids ODR violations for this header-defined free function.
inline ck::tensor_operation::device::GemmSpecialization gemm_type(const std::string& type)
{
    if(type == "ck::tensor_operation::device::GemmSpecialization::Default")
    {
        return ck::tensor_operation::device::GemmSpecialization::Default;
    }
    if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding")
    {
        return ck::tensor_operation::device::GemmSpecialization::MNKPadding;
    }
    throw std::runtime_error("Incorrect gemm spec: " + type);
}
|
||||
|
||||
// return the type of convolution
|
||||
// Map a ConvolutionForwardSpecialization name (as it appears in generated template
// parameters) to the corresponding enum value.
// \throws std::runtime_error for any unsupported specialization name.
// 'inline' avoids ODR violations for this header-defined free function.
inline ck::tensor_operation::device::ConvolutionForwardSpecialization
conv_type(const std::string& type)
{
    using CFS = ck::tensor_operation::device::ConvolutionForwardSpecialization;
    if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default")
    {
        return CFS::Default;
    }
    if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0")
    {
        return CFS::Filter1x1Pad0;
    }
    if(type ==
       "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0")
    {
        return CFS::Filter1x1Stride1Pad0;
    }
    if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC")
    {
        return CFS::OddC;
    }
    throw std::runtime_error("Incorrect conv spec: " + type);
}
|
||||
|
||||
// Function to call on MatrixPadder via a wrapper struct
|
||||
// NOTE: CK only uses MNKPadding for forward convolution
|
||||
template <typename CDesc_MRaw_NRaw>
|
||||
auto pad(ck::index_t mpb,
|
||||
ck::index_t npb,
|
||||
ck::index_t kpb,
|
||||
ck::tensor_operation::device::GemmSpecialization gemm,
|
||||
CDesc_MRaw_NRaw conv)
|
||||
{
|
||||
if(gemm == ck::tensor_operation::device::GemmSpecialization::MNKPadding)
|
||||
{
|
||||
ck::tensor_operation::device::MatrixPadder<
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
ck::index_t,
|
||||
ck::index_t,
|
||||
ck::index_t>
|
||||
a;
|
||||
a.MPerTile_ = mpb;
|
||||
a.NPerTile_ = npb;
|
||||
a.KPerTile_ = kpb;
|
||||
auto tmp = grid_desc(a, conv);
|
||||
return tmp;
|
||||
}
|
||||
throw std::runtime_error("Incorrect template parameters, check gemm spec");
|
||||
}
|
||||
|
||||
// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num
|
||||
// dims
|
||||
// FIXME: add a way to properly pass in the layout
|
||||
auto transform_conv(ck::index_t num_dim,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
|
||||
ck::Array<ck::index_t, 5> out_lengths,
|
||||
ck::Array<ck::index_t, 5> out_strides)
|
||||
{
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect conv spec");
|
||||
}
|
||||
|
||||
auto transform_conv_3d(ck::index_t num_dim,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
|
||||
ck::Array<ck::index_t, 6> out_lengths,
|
||||
ck::Array<ck::index_t, 6> out_strides)
|
||||
{
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect conv spec");
|
||||
}
|
||||
|
||||
auto transform_conv_1d(ck::index_t num_dim,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization spec,
|
||||
ck::Array<ck::index_t, 4> out_lengths,
|
||||
ck::Array<ck::index_t, 4> out_strides)
|
||||
{
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd;
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect dims or conv spec");
|
||||
}
|
||||
|
||||
template <typename CGridDesc_M_N>
|
||||
auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder)
|
||||
{
|
||||
if(m_per_block == 32 && n_per_block == 64)
|
||||
{
|
||||
auto b2e = ck::BlockToCTileMap_M00_N0_M01Adapt<32, 64, CGridDesc_M_N>(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 32 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<32, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 64 && n_per_block == 32)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 32, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 64 && n_per_block == 64)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 64, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 64 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<64, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 32)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 32, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 64)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 64, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 128 && n_per_block == 256)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<128, 256, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
if(m_per_block == 256 && n_per_block == 128)
|
||||
{
|
||||
ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, CGridDesc_M_N> b2e(matrix_padder);
|
||||
return b2e.CalculateGridSize(matrix_padder);
|
||||
}
|
||||
throw std::runtime_error("Incorrect template parameters");
|
||||
}
|
||||
|
||||
// wrapper functions by dims to get grid size - uses above 3 functions
|
||||
// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions
|
||||
auto get_launch_params_1d(ck::host::Solution solution,
|
||||
ck::Array<ck::index_t, 4> out_lengths,
|
||||
ck::Array<ck::index_t, 4> out_strides)
|
||||
{
|
||||
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
|
||||
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
|
||||
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
|
||||
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
|
||||
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
|
||||
auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides);
|
||||
auto matrix_padder =
|
||||
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
|
||||
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
|
||||
return b2e;
|
||||
}
|
||||
|
||||
auto get_launch_params(ck::host::Solution solution,
|
||||
ck::Array<ck::index_t, 5> out_lengths,
|
||||
ck::Array<ck::index_t, 5> out_strides)
|
||||
{
|
||||
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
|
||||
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
|
||||
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
|
||||
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
|
||||
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
|
||||
auto conv_to_gemm_transformer = transform_conv(num_dim, ConvSpec, out_lengths, out_strides);
|
||||
auto matrix_padder =
|
||||
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
|
||||
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
|
||||
return b2e;
|
||||
}
|
||||
|
||||
auto get_launch_params_3d(ck::host::Solution solution,
|
||||
ck::Array<ck::index_t, 6> out_lengths,
|
||||
ck::Array<ck::index_t, 6> out_strides)
|
||||
{
|
||||
auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
|
||||
auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
|
||||
auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
|
||||
auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
|
||||
auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
|
||||
ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType);
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType);
|
||||
auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides);
|
||||
auto matrix_padder =
|
||||
pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
|
||||
auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
|
||||
return b2e;
|
||||
}
|
||||
@@ -0,0 +1,781 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <iterator>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
namespace {
|
||||
|
||||
/*
|
||||
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
|
||||
*
|
||||
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
|
||||
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
|
||||
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
|
||||
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
|
||||
* limitations.
|
||||
*
|
||||
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
|
||||
* returns the 2D index of the tile that it computes. \see
|
||||
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
|
||||
*
|
||||
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
|
||||
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
|
||||
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
|
||||
* impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
|
||||
* \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
|
||||
* computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
|
||||
*
|
||||
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
|
||||
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
|
||||
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
|
||||
*
|
||||
*/
|
||||
template <typename GridwiseGemm,
          typename AsPointer, // tuples if multi AB, pointers if no
          typename BsPointer,
          typename DsPointer,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2ETileMap,
          typename ComputePtrOffsetOfBatch,
          bool HasMainKBlockLoop,
          bool isMultiA,
          bool isMultiB>
__device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
    AsPointer p_as_grid,
    BsPointer p_bs_grid,
    DsPointer p_ds_grid,
    EDataType* __restrict__ p_e_grid,
    const AElementwiseOperation a_element_op,
    const BElementwiseOperation b_element_op,
    const CDEElementwiseOperation cde_element_op,
    const index_t batch_count,
    const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
    const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
    const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
        ds_grid_desc_mblock_mperblock_nblock_nperblock,
    const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
        e_grid_desc_mblock_mperblock_nblock_nperblock_,
    const Block2ETileMap block_2_ctile_map,
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
// Real body only for device passes targeting arches with XDL support
// (gfx908/90a/94x) and for host-side compilation; otherwise all parameters are
// consumed via 'ignore' to keep the signature warning-free.
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx94__))
    // offset base pointer for each work-group
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    // batch (group) index this workgroup belongs to; readfirstlane keeps it scalar
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
    const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

    // LDS workspace, sized by the gridwise GEMM
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    DsPointer p_ds_grid_grp;

    static constexpr index_t NumDTensor =
        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();

    // shift every D pointer to this batch's slice
    static_for<0, NumDTensor, 1>{}(
        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });

    if constexpr(isMultiA || isMultiB)
    {
        // multi-A/B path: A and B arrive as tuples; offset each element per batch
        AsPointer p_as_grid_grp;
        BsPointer p_bs_grid_grp;

        const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);

        // NOTE(review): tensor counts are taken from the grid-descriptor tuples'
        // Size(); assumes descriptor tuple arity matches pointer tuple arity.
        static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
        static_for<0, NumATensor, 1>{}(
            [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });

        const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);

        static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
        static_for<0, NumBTensor, 1>{}(
            [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });

        GridwiseGemm::template Run<HasMainKBlockLoop>(
            p_as_grid_grp,
            p_bs_grid_grp,
            p_ds_grid_grp,
            p_e_grid + e_batch_offset,
            p_shared,
            a_element_op,
            b_element_op,
            cde_element_op,
            a_grid_desc_k0_m_k1,
            b_grid_desc_k0_n_k1,
            ds_grid_desc_mblock_mperblock_nblock_nperblock,
            e_grid_desc_mblock_mperblock_nblock_nperblock_,
            block_2_ctile_map);
    }
    else
    {
        // single-A/B path: plain pointers offset by scalar per-batch strides
        const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
        const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
            static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));

        GridwiseGemm::template Run<HasMainKBlockLoop>(
            p_as_grid + a_batch_offset,
            p_bs_grid + b_batch_offset,
            p_ds_grid_grp,
            p_e_grid + e_batch_offset,
            p_shared,
            a_element_op,
            b_element_op,
            cde_element_op,
            a_grid_desc_k0_m_k1,
            b_grid_desc_k0_n_k1,
            ds_grid_desc_mblock_mperblock_nblock_nperblock,
            e_grid_desc_mblock_mperblock_nblock_nperblock_,
            block_2_ctile_map);
    }
#else
    // unsupported arch: suppress unused-parameter warnings
    ignore = p_as_grid;
    ignore = p_bs_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
    ignore = batch_count;
    ignore = a_grid_desc_k0_m_k1;
    ignore = b_grid_desc_k0_n_k1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = compute_ptr_offset_of_batch;
    ignore = block_2_ctile_map;
#endif
}
|
||||
|
||||
// Thin __global__ entry point for grouped forward convolution
// (multiple-A/B/D operands, XDL instructions, C-shuffle epilogue).
// It does no work itself: it only forwards every kernel argument to the
// device-side helper device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.
template <typename GridwiseGemm,
          typename AsPointer, // tuples if multi AB, pointers if no
          typename BsPointer,
          typename DsPointer,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2ETileMap,
          typename ComputePtrOffsetOfBatch,
          bool HasMainKBlockLoop,
          bool isMultiA,
          bool isMultiB>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
            AsPointer p_as_grid,
            BsPointer p_bs_grid,
            DsPointer p_ds_grid,
            EDataType* __restrict__ p_e_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const index_t batch_count,
            const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock_,
            const Block2ETileMap block_2_ctile_map,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
    // NOTE(review): p_e_grid is dereferenced here (*p_e_grid) while the
    // callee's #else fallback above ignores it as a pointer-like value —
    // confirm the helper's E parameter really is a reference (EDataType&)
    // and not EDataType*.
    device_grouped_conv_fwd_multiple_abd_xdl_cshuffle<
        GridwiseGemm,
        AsPointer, // tuples if multi AB, pointers if no
        BsPointer,
        DsPointer,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
        Block2ETileMap,
        ComputePtrOffsetOfBatch,
        HasMainKBlockLoop,
        isMultiA,
        isMultiB>(p_as_grid,
                  p_bs_grid,
                  p_ds_grid,
                  *p_e_grid,
                  a_element_op,
                  b_element_op,
                  cde_element_op,
                  batch_count,
                  a_grid_desc_k0_m_k1,
                  b_grid_desc_k0_n_k1,
                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
                  e_grid_desc_mblock_mperblock_nblock_nperblock_,
                  block_2_ctile_map,
                  compute_ptr_offset_of_batch);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Detection-idiom probe: well-formed only when T exposes an IsTuple()
// member (i.e. T is a ck tuple type). Used with is_detected below to
// decide the multi-A/B code paths at compile time.
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
//
// @brief Device Convolution operation (codegen variant).
//
// Supports:
// @li Forward convolution with up to 3 spatial dimensions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          ConvolutionForwardSpecialization ConvForwardSpecialization,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          typename ComputeDataType =
              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                      Number<0>,
                                      ADataType>()), // ComputeType is InputType by default (first
                                                     // in tuple for MultiAB), unpack if tuple was
                                                     // passed
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
                                             ALayout,
                                             BLayout,
                                             DsLayout,
                                             ELayout,
                                             ADataType,
                                             BDataType,
                                             DsDataType,
                                             EDataType,
                                             AElementwiseOperation,
                                             BElementwiseOperation,
                                             CDEElementwiseOperation,
                                             ComputeDataType>
{
    using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;

    // Multi-A/B mode is inferred from the data type itself: a tuple data
    // type (detected through its IsTuple() member) means multiple tensors
    // for that operand.
    static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
    static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;

    static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
    static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    // Maps convolution problem descriptions onto GEMM-shaped descriptors.
    static constexpr auto conv_to_gemm_transformer =
        TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};

    // Pads raw GEMM dimensions up to multiples of the tile sizes according
    // to GemmSpec.
    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

    // Build the A (input image) descriptor in GEMM [M, K] form:
    // conv-to-GEMM transform first, then tile padding.
    template <typename ALay>
    __host__ __device__ static auto
    MakeAGridDescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                            const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                            const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                            const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
                            const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                            const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
                            const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                            const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                            const ck::Array<index_t, NDimSpatial>& input_left_pads,
                            const ck::Array<index_t, NDimSpatial>& input_right_pads)
    {
        const auto in_gemmmraw_gemmkraw_desc =
            conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
                                                                        a_g_n_c_wis_strides,
                                                                        b_g_k_c_xs_lengths,
                                                                        b_g_k_c_xs_strides,
                                                                        e_g_n_k_wos_lengths,
                                                                        e_g_n_k_wos_strides,
                                                                        conv_filter_strides,
                                                                        conv_filter_dilations,
                                                                        input_left_pads,
                                                                        input_right_pads);

        const auto in_gemmm_gemmk_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);

        return in_gemmm_gemmk_desc;
    }

    // Build the B (weight) descriptor in GEMM [N, K] form, tile-padded.
    template <typename BLay>
    __host__ __device__ static auto
    MakeBGridDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                            const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
    {
        const auto wei_gemmnraw_gemmkraw_desc =
            conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
                                                                        b_g_k_c_xs_strides);

        const auto wei_gemmn_gemmk_desc =
            matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);

        return wei_gemmn_gemmk_desc;
    }

    // Build the E (output) descriptor in GEMM [M, N] form, tile-padded.
    // Also reused for each D tensor (see MakeDsGridDescriptor_M_N).
    template <typename ELay>
    __host__ __device__ static auto
    MakeEGridDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                            const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
    {
        const auto out_gemmmraw_gemmnraw_desc =
            conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
                                                                        e_g_n_k_wos_strides);

        const auto out_gemmm_gemmn_desc =
            matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);

        return out_gemmm_gemmn_desc;
    }

    // Shape of Ds and E must be aligned. Strides can be different.
    // Pass e_g_n_k_wos_lengths for logical broadcast.
    __host__ __device__ static auto MakeDsGridDescriptor_M_N(
        const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
        const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
    {
        // One [M, N] descriptor per D tensor, sharing E's lengths but each
        // with its own strides.
        return generate_tuple(
            [&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

                return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
                                                                  ds_g_n_k_wos_strides[i]);
            },
            Number<NumDTensor>{});
    }

    // desc for problem definition
    using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
        {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
    using BGridDesc_N_K  = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
    using EGridDesc_M_N  = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;

    // If we are using multiAB and one of the template datatype parameters is not a tuple, convert
    // it to it
    using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
    using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;

#define GridwiseGemmTemplateParameters \
    GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
        EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
        InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
        KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
        ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
        ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
        ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
        ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
        BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
        BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
        BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
        CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
    // Use appropriate gridwise gemm: the MultipleABD variant only when at
    // least one operand is a tuple, otherwise the MultipleD variant.
    using GridwiseGemm =
        std::conditional_t<isMultiA || isMultiB,
                           GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
                           GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;

    // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
    using APointers =
        std::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
    using BPointers =
        std::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
    // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
    // in initializer list what is required for single const pointer).
    using AGridPointer = remove_cvref_t<
        decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>;
    using BGridPointer = remove_cvref_t<
        decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>;

    // desc for blockwise copy
    using AGridDesc_AK0_M_AK1 =
        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
            AGridDesc_M_K{}))>;
    using BGridDesc_BK0_N_BK1 =
        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
            BGridDesc_N_K{}))>;
    using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<
        decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            DsGridDesc_M_N{}))>;
    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
        remove_cvref_t<decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
            EGridDesc_M_N{}))>;

    // block-to-e-tile map
    using Block2ETileMap =
        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;

    // Argument: holds every pointer, descriptor, batch-offset helper and
    // element-wise op the kernel launch needs, derived once from the raw
    // convolution problem description.
    struct Argument
    {
        __device__ __host__ Argument(
            APointers p_as,
            BPointers p_bs,
            const ck::Array<const void*, NumDTensor>& p_ds,
            void* p_e,
            const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
            const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
            const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
            const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
            const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
            const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
            const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
            const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
            const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
            const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
            const ck::Array<index_t, NDimSpatial>& input_left_pads,
            const ck::Array<index_t, NDimSpatial>& input_right_pads,
            const AElementwiseOperation& a_element_op,
            const BElementwiseOperation& b_element_op,
            const CDEElementwiseOperation& cde_element_op)
            : p_as_grid_{},
              p_bs_grid_{},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e)},
              num_group_{a_g_n_c_wis_lengths[0]}, // index 0 is the group (G) dimension
              a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
                                                                          a_g_n_c_wis_strides,
                                                                          b_g_k_c_xs_lengths,
                                                                          b_g_k_c_xs_strides,
                                                                          e_g_n_k_wos_lengths,
                                                                          e_g_n_k_wos_strides,
                                                                          conv_filter_strides,
                                                                          conv_filter_dilations,
                                                                          input_left_pads,
                                                                          input_right_pads)},
              b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
                                                                          b_g_k_c_xs_strides)},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
                                                                          e_g_n_k_wos_strides)},
              a_grid_desc_ak0_m_ak1_{
                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
              b_grid_desc_bk0_n_bk1_{
                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
              ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
              compute_ptr_offset_of_batch_{},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op},
              a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
              a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
              b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
              b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
              ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
              ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
              e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
              e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
              conv_filter_strides_{conv_filter_strides},
              conv_filter_dilations_{conv_filter_dilations},
              input_left_pads_{input_left_pads},
              input_right_pads_{input_right_pads}
        {
            // A/B/E Batch Stride
            if constexpr(isMultiA || isMultiB)
            {
                static_for<0, NumATensor, 1>{}([&](auto i) {
                    // Init compute_ptr_offset_of_batch_ for multiple AB
                    compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];

                    // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
                    // type is not tuple)
                    using DataType = remove_cvref_t<tuple_element_t<i.value, GemmADataType>>;
                    // It is possible that one of the AB is a pointer and one is a tuple.
                    // Then also use multiAB but we have to cast single pointer instead of tuple of
                    // pointer.
                    if constexpr(isMultiA)
                    {
                        // p_as is tuple
                        p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
                    }
                    else
                    {
                        // if MultiB and not MultiA then p_as is single pointer
                        p_as_grid_(i) = static_cast<const DataType*>(p_as);
                    }
                });
                static_for<0, NumBTensor, 1>{}([&](auto i) {
                    // Init compute_ptr_offset_of_batch_ for multiple AB
                    compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];

                    using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
                    // It is possible that one of the AB is a pointer and one is a tuple.
                    // Then also use multiAB but we have to cast single pointer instead of tuple of
                    // pointer.
                    if constexpr(isMultiB)
                    {
                        // p_bs is tuple
                        p_bs_grid_(i) = static_cast<const DataType*>(p_bs[i.value]);
                    }
                    else
                    {
                        // if MultiA and not MultiB then p_bs is single pointer
                        p_bs_grid_(i) = static_cast<const DataType*>(p_bs);
                    }
                });
            }
            else
            {
                compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
                compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];

                // p_as and p_bs are pointers
                p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
                p_bs_grid_(I0) = static_cast<const BDataType*>(p_bs);
            }

            // populate pointer, batch stride, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DLayout   = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);

                // D batch stride
                compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];

                // D desc (shares E's lengths; see MakeDsGridDescriptor_M_N)
                ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
                    e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
            });
            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];

            // populate desc for Ds/E; blockwise descriptors are only built when
            // the problem passes GridwiseGemm::CheckValidity.
            if constexpr(isMultiA || isMultiB)
            {
                // NOTE(review): these tuples replicate the raw M_K/N_K
                // descriptors but are named *_ak0_m_ak1/*_bk0_n_bk1 — confirm
                // against the MultipleABD CheckValidity() signature.
                const auto as_grid_desc_ak0_m_ak1 =
                    generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number<NumATensor>{});
                const auto bs_grid_desc_bk0_n_bk1 =
                    generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number<NumBTensor>{});

                if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1,
                                               bs_grid_desc_bk0_n_bk1,
                                               ds_grid_desc_m_n_,
                                               e_grid_desc_m_n_,
                                               block_2_etile_map_))
                {
                    e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                            e_grid_desc_m_n_);

                    ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
                        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                            ds_grid_desc_m_n_);
                }
            }
            else
            {
                if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
                                               b_grid_desc_n_k_,
                                               ds_grid_desc_m_n_,
                                               e_grid_desc_m_n_,
                                               block_2_etile_map_))
                {
                    e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                            e_grid_desc_m_n_);

                    ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
                        GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                            ds_grid_desc_m_n_);
                }
            }
        }

        //  private:
        // pointers (tuple if multi AB, pointer if no)
        AGridPointer p_as_grid_;
        BGridPointer p_bs_grid_;
        typename GridwiseGemm::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;

        // tensor descriptors for problem definition
        index_t num_group_;
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;

        // block-to-e-tile map
        Block2ETileMap block_2_etile_map_;

        // for computing batch offset
        ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
            compute_ptr_offset_of_batch_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;

        // for checking IsSupportedArgument()
        ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
        ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
        ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
        ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
        ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
        ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
        ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
        ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
        ck::Array<index_t, NDimSpatial> conv_filter_strides_;
        ck::Array<index_t, NDimSpatial> conv_filter_dilations_;
        ck::Array<index_t, NDimSpatial> input_left_pads_;
        ck::Array<index_t, NDimSpatial> input_right_pads_;
    };

    // Convenience factory: forwards the raw problem description to the
    // Argument constructor above.
    static __device__ __host__ auto MakeArgument(
        APointers p_as,
        BPointers p_bs,
        const ck::Array<const void*, NumDTensor>& p_ds,
        void* p_e,
        const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
        const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
        const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
        const ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
        const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
        const ck::Array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
        const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
        const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
        const ck::Array<index_t, NDimSpatial>& input_left_pads,
        const ck::Array<index_t, NDimSpatial>& input_right_pads,
        const AElementwiseOperation& a_element_op,
        const BElementwiseOperation& b_element_op,
        const CDEElementwiseOperation& cde_element_op)
    {
        return Argument{p_as,
                        p_bs,
                        p_ds,
                        p_e,
                        a_g_n_c_wis_lengths,
                        a_g_n_c_wis_strides,
                        b_g_k_c_xs_lengths,
                        b_g_k_c_xs_strides,
                        ds_g_n_k_wos_lengths,
                        ds_g_n_k_wos_strides,
                        e_g_n_k_wos_lengths,
                        e_g_n_k_wos_strides,
                        conv_filter_strides,
                        conv_filter_dilations,
                        input_left_pads,
                        input_right_pads,
                        a_element_op,
                        b_element_op,
                        cde_element_op};
    }
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -180,6 +180,19 @@ struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KP
|
||||
{
|
||||
};
|
||||
|
||||
// function to take in a struct of type MatrixPadder and call the appropriate function to get
|
||||
// the output descriptor at runtime for codegen
|
||||
template <GemmSpecialization GemmSpec,
|
||||
typename MPerTileType,
|
||||
typename NPerTileType,
|
||||
typename KPerTileType,
|
||||
typename CDesc_MRaw_NRaw>
|
||||
auto grid_desc(MatrixPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType> matrix_padder,
|
||||
CDesc_MRaw_NRaw conv_desc)
|
||||
{
|
||||
auto res = matrix_padder.PadCDescriptor_M_N(conv_desc);
|
||||
return res;
|
||||
}
|
||||
// M/N/KPerTileType could be index_t or Number<>
|
||||
template <bool PadM,
|
||||
bool PadN,
|
||||
|
||||
Reference in New Issue
Block a user