[CK_TILE] Add conv fwd + bias + clamp example (#3012)

* Implement argument passing to element-wise functions for fwd convolution

* Add files for fwd + bias + clamp example

* Implement Bias

* Implement Clamp

* Elementwise function composition

* Composition unit test

* Implement fwd + bias + clamp example

* Simplify argument passing and composition

* elfunc -> bias_and_clamp

* Rename function to specify example

* Move element-wise function instantiation to kernel

* Make bias a runtime tensor

* No ugly namespace aliasing

* Initialize element-wise function on host

* Remove function initialization helper, simplify Compose initialization

* Remove unintended LSP compatibility patch

* Clean up includes and unused code

* Switch names in cshuffle epilogue

* Move CDElementwise to conv traits

* Re-add required include

* Initialize bias in same way as other tensors

* Better type specification for ds pointer

* Disable 1D convolution

* Add warning for non-group-constant bias
This commit is contained in:
Johannes Graner
2025-10-27 18:43:09 +01:00
committed by GitHub
parent 054fdb765c
commit 5c1974065e
11 changed files with 524 additions and 41 deletions

View File

@@ -7,6 +7,7 @@
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
@@ -14,14 +15,18 @@ namespace ck_tile {
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
typename OutDataType,
typename Elfunc = ck_tile::element_wise::PassThrough,
typename Tuple = ck_tile::tuple<>>
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input,
const HostTensor<WeiDataType>& weight,
HostTensor<OutDataType>& output,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
std::vector<ck_tile::long_index_t>)
std::vector<ck_tile::long_index_t>,
Elfunc elfunc = Elfunc{},
Tuple ds = {})
{
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
weight.get_num_of_dimension() == NDimSpatial + 3 &&
@@ -52,8 +57,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, wo) = v_acc_converted;
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,
@@ -95,8 +104,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, ho, wo) = v_acc_converted;
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, ho, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, ho, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,
@@ -145,8 +158,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, d_o, ho, wo) = v_acc_converted;
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, d_o, ho, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, d_o, ho, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,