[CK_TILE] Add conv fwd + bias + clamp example (#3012)

* Implement argument passing to element-wise functions for fwd convolution

* Add files for fwd + bias + clamp example

* Implement Bias

* Implement Clamp

* Elementwise function composition

* Composition unit test

* Implement fwd + bias + clamp example

* Simplify argument passing and composition

* elfunc -> bias_and_clamp

* Rename function to specify example

* Move element-wise function instantiation to kernel

* Make bias a runtime tensor

* No ugly namespace aliasing

* Initialize element-wise function on host

* Remove function initialization helper, simplify Compose initialization

* Remove unintended LSP compatibility patch

* Clean up includes and unused code

* Switch names in cshuffle epilogue

* Move CDElementwise to conv traits

* Re-add required include

* Initialize bias in same way as other tensors

* Better type specification for ds pointer

* Disable 1D convolution

* Add warning for non-group-constant bias
This commit is contained in:
Johannes Graner
2025-10-27 18:43:09 +01:00
committed by GitHub
parent 054fdb765c
commit 5c1974065e
11 changed files with 524 additions and 41 deletions

View File

@@ -7,10 +7,12 @@
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/core/utility/env.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp"
@@ -28,6 +30,7 @@ struct GroupedConvFwdKernelArgs
GroupedConvTraitsType_::VectorSizeB,
GroupedConvTraitsType_::VectorSizeC,
true>; // Split N enabled
using CDElementwise = typename GroupedConvTraitsType_::CDElementwise;
static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
template <
@@ -38,7 +41,8 @@ struct GroupedConvFwdKernelArgs
std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
: elfunc(args.elfunc)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
@@ -121,7 +125,8 @@ struct GroupedConvFwdKernelArgs
std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
: elfunc(args.elfunc)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
@@ -213,7 +218,8 @@ struct GroupedConvFwdKernelArgs
std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
bool>::type = false>
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& args)
: elfunc(args.elfunc)
{
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
static_cast<index_t>(args.N_),
@@ -335,6 +341,7 @@ struct GroupedConvFwdKernelArgs
const void* in_ptr;
const void* wei_ptr;
std::array<const void*, NumDTensor> ds_ptr;
const CDElementwise elfunc;
void* out_ptr;
AGridDescMK a_grid_desc_m_k;
@@ -423,6 +430,8 @@ struct GroupedConvolutionForwardKernel
// Below type is actually accumulation data type - the output of block GEMM.
using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using CDElementwise = typename EpiloguePipeline::CDElementwise;
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType_>;
// TODO: Enable this
@@ -458,7 +467,7 @@ struct GroupedConvolutionForwardKernel
}
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs)
MakeKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& hostArgs)
{
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
}
@@ -636,7 +645,7 @@ struct GroupedConvolutionForwardKernel
"Not supported!");
return make_tensor_view<address_space_enum::global>(
static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
},
number<NumDTensor>{});
@@ -765,8 +774,9 @@ struct GroupedConvolutionForwardKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{kargs.elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
/**