[CK_TILE] Add conv fwd + bias + clamp example (#3012)

* Implement argument passing to element-wise functions for fwd convolution

* Add files for fwd + bias + clamp example

* Implement Bias

* Implement Clamp

* Elementwise function composition

* Composition unit test

* Implement fwd + bias + clamp example

* Simplify argument passing and composition

* elfunc -> bias_and_clamp

* Rename function to specify example

* Move element-wise function instantiation to kernel

* Make bias a runtime tensor

* No ugly namespace aliasing

* Initialize element-wise function on host

* Remove function initialization helper, simplify Compose initialization

* Remove unintended LSP compatibility patch

* Clean up includes and unused code

* Switch names in cshuffle epilogue

* Move CDElementwise to conv traits

* Re-add required include

* Initialize bias in same way as other tensors

* Better type specification for ds pointer

* Disable 1D convolution

* Add warning for non-group-constant bias
This commit is contained in:
Johannes Graner
2025-10-27 18:43:09 +01:00
committed by GitHub
parent 054fdb765c
commit 5c1974065e
11 changed files with 524 additions and 41 deletions

View File

@@ -1540,6 +1540,23 @@ struct Logistic
const float alpha_;
};
/**
 * @brief Element-wise clamp of a value into the closed range [lower_, upper_].
 *
 * The bounds are stored as `float` and converted to the operand type on every
 * call, so a single functor instance can be applied to different element types.
 */
struct Clamp
{
    /**
     * @param lower Lower bound; defaults to the lowest finite `float`.
     * @param upper Upper bound; defaults to the largest finite `float`.
     */
    CK_TILE_HOST_DEVICE Clamp(float lower = std::numeric_limits<float>::lowest(),
                              float upper = std::numeric_limits<float>::max())
        : lower_(lower), upper_(upper)
    {
    }

    /**
     * @brief Write `x` limited to the stored bounds into `y`.
     *
     * @tparam T Element type shared by input and output.
     * @param y Output element.
     * @param x Input element.
     */
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr void operator()(T& y, const T& x) const
    {
        // Bring both float bounds into the operand's type before comparing.
        const T lo = ck_tile::type_convert<T>(lower_);
        const T hi = ck_tile::type_convert<T>(upper_);

        y = ck_tile::clamp(x, lo, hi);
    }

    float lower_, upper_;
};
struct ConvInvscale
{
static constexpr const char* name = "ConvInvscale";
@@ -1629,6 +1646,55 @@ struct Cast
};
};
/**
 * @brief Chain two element-wise functions so they act as a single one.
 *
 * Calling the composed object runs `FuncA` on the input, stores its result in
 * a temporary, and then runs `FuncB` on that temporary to produce the output.
 *
 * @note At most one of the two functions may receive the Ds tensor arguments;
 *       this is enforced at compile time. The restriction carries over to
 *       nested compositions: in `Compose<FA, Compose<FB, FC>>`, only one of
 *       `FA`, `FB`, or `FC` can use the Ds tensor.
 *
 * @tparam FuncA The function applied first.
 * @tparam FuncB The function applied second, on the result of `FuncA`.
 * @tparam FuncADs Set to `true` if `FuncA` takes the Ds arguments.
 * @tparam FuncBDs Set to `true` if `FuncB` takes the Ds arguments.
 */
template <typename FuncA, typename FuncB, bool FuncADs = false, bool FuncBDs = false>
struct Compose
{
    static_assert(!(FuncADs && FuncBDs), "Only one composed function may use the Ds tensor.");

    /// Store copies of both functions; each defaults to a value-initialized instance.
    CK_TILE_HOST_DEVICE Compose(FuncA func_a_ = FuncA{}, FuncB func_b_ = FuncB{})
        : func_a(func_a_), func_b(func_b_)
    {
    }

    /**
     * @brief Apply `func_a` followed by `func_b`.
     *
     * @tparam AIn  Input type of `func_a`.
     * @tparam BOut Output type of `func_b`.
     * @tparam AOut Type of the intermediate value handed from `func_a` to
     *              `func_b`; defaults to `AIn`.
     * @tparam ADs  Types of the extra Ds arguments.
     *
     * @param y  Final output, written by `func_b`.
     * @param x  Input, read by `func_a`.
     * @param ds Extra arguments forwarded to whichever function is flagged as
     *           using them; not forwarded at all when neither flag is set.
     */
    template <typename AIn, typename BOut, typename AOut = AIn, typename... ADs>
    CK_TILE_HOST_DEVICE constexpr void operator()(BOut& y, const AIn& x, const ADs&... ds) const
    {
        AOut intermediate;

        // First stage: x -> intermediate.
        if constexpr(FuncADs)
        {
            func_a(intermediate, x, ds...);
        }
        else
        {
            func_a(intermediate, x);
        }

        // Second stage: intermediate -> y. The static_assert above guarantees
        // that the Ds arguments are never handed to both stages.
        if constexpr(FuncBDs)
        {
            func_b(y, intermediate, ds...);
        }
        else
        {
            func_b(y, intermediate);
        }
    }

    const FuncA func_a;
    const FuncB func_b;
};
// support fastconvert of int8 to fp16
#if 0
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>