[CK_TILE] Add conv fwd + bias + clamp example (#3012)

* Implement argument passing to element-wise functions for fwd convolution

* Add files for fwd + bias + clamp example

* Implement Bias

* Implement Clamp

* Elementwise function composition

* Composition unit test

* Implement fwd + bias + clamp example

* Simplify argument passing and composition

* elfunc -> bias_and_clamp

* Rename function to specify example

* Move element-wise function instantiation to kernel

* Make bias a runtime tensor

* No ugly namespace aliasing

* Initialize element-wise function on host

* Remove function initialization helper, simplify Compose initialization

* Remove unintended LSP compatibility patch

* Clean up includes and unused code

* Switch names in cshuffle epilogue

* Move CDElementwise to conv traits

* Re-add required include

* Initialize bias in same way as other tensors

* Better type specification for ds pointer

* Disable 1D convolution

* Add warning for non-group-constant bias
This commit is contained in:
Johannes Graner
2025-10-27 18:43:09 +01:00
committed by GitHub
parent 054fdb765c
commit 5c1974065e
11 changed files with 524 additions and 41 deletions

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
namespace ck_tile {
@@ -14,7 +15,7 @@ namespace ck_tile {
/// This structure is passed to Grouped Convolution Kernels when creating kernel
/// arguments object. It contain all necessary information required to
/// build proper kernel argument and launch kernel on GPU.
template <typename InPtr, typename WeiPtr, typename OutPtr>
template <typename InPtr, typename WeiPtr, typename OutPtr, typename CDElementwise>
struct GroupedConvHostArgs : public conv::ConvParam
{
CK_TILE_HOST GroupedConvHostArgs() = delete;
@@ -23,13 +24,15 @@ struct GroupedConvHostArgs : public conv::ConvParam
WeiPtr wei_ptr_,
const std::vector<const void*> ds_ptr_,
OutPtr out_ptr_,
index_t k_batch_)
index_t k_batch_,
CDElementwise elfunc_ = CDElementwise{})
: conv::ConvParam(conv_param),
in_ptr(in_ptr_),
wei_ptr(wei_ptr_),
ds_ptr(ds_ptr_),
out_ptr(out_ptr_),
k_batch(k_batch_)
k_batch(k_batch_),
elfunc(elfunc_)
{
}
@@ -38,11 +41,17 @@ struct GroupedConvHostArgs : public conv::ConvParam
const std::vector<const void*> ds_ptr;
OutPtr out_ptr;
index_t k_batch;
const CDElementwise elfunc;
};
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*>;
using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs<const void*, void*, const void*>;
using GroupedConvBwdDataHostArgs = GroupedConvHostArgs<void*, const void*, const void*>;
using PassThrough = ck_tile::element_wise::PassThrough;
template <typename CDElementwise = PassThrough>
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*, CDElementwise>;
using GroupedConvBwdWeightHostArgs =
GroupedConvHostArgs<const void*, void*, const void*, PassThrough>;
using GroupedConvBwdDataHostArgs =
GroupedConvHostArgs<void*, const void*, const void*, PassThrough>;
template <index_t NDimSpatial_,
ConvolutionSpecialization ConvSpecialization_,
@@ -50,9 +59,10 @@ template <index_t NDimSpatial_,
typename WeiLayout_,
typename DsLayout_,
typename OutLayout_,
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1>
index_t VectorSizeA_ = 1,
index_t VectorSizeB_ = 1,
index_t VectorSizeC_ = 1,
typename CDElementwise_ = PassThrough>
struct GroupedConvTraits
{
private:
@@ -70,6 +80,7 @@ struct GroupedConvTraits
using WeiLayout = WeiLayout_;
using DsLayout = DsLayout_;
using OutLayout = OutLayout_;
using CDElementwise = CDElementwise_;
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true,
true,