mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
[CK_TILE] Add conv fwd + bias + clamp example (#3012)
* Implement argument passing to element-wise functions for fwd convolution * Add files for fwd + bias + clamp example * Implement Bias * Implement Clamp * Elementwise function composition * Composition unit test * Implement fwd + bias + clamp example * Simplify argument passing and composition * elfunc -> bias_and_clamp * Rename function to specify example * Move element-wise function instantiation to kernel * Make bias a runtime tensor * No ugly namespace aliasing * Initialize element-wise function on host * Remove function initialization helper, simplify Compose initialization * Remove unintended LSP compatibility patch * Clean up includes and unused code * Switch names in cshuffle epilogue * Move CDElementwise to conv traits * Re-add required include * Initialize bias in same way as other tensors * Better type specification for ds pointer * Disable 1D convolution * Add warning for non-group-constant bias
This commit is contained in:
@@ -1540,6 +1540,23 @@ struct Logistic
|
||||
const float alpha_;
|
||||
};
|
||||
|
||||
struct Clamp
|
||||
{
|
||||
CK_TILE_HOST_DEVICE Clamp(float lower = std::numeric_limits<float>::lowest(),
|
||||
float upper = std::numeric_limits<float>::max())
|
||||
: lower_(lower), upper_(upper) {};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(T& y, const T& x) const
|
||||
{
|
||||
T lower = ck_tile::type_convert<T>(lower_);
|
||||
T upper = ck_tile::type_convert<T>(upper_);
|
||||
y = ck_tile::clamp(x, lower, upper);
|
||||
}
|
||||
|
||||
float lower_, upper_;
|
||||
};
|
||||
|
||||
struct ConvInvscale
|
||||
{
|
||||
static constexpr const char* name = "ConvInvscale";
|
||||
@@ -1629,6 +1646,55 @@ struct Cast
|
||||
};
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compose two unary element-wise functions into one.
|
||||
*
|
||||
*
|
||||
* @note The Ds tensor can be used by at most one of the composed functions.
|
||||
* This holds even if compositions are chained:
|
||||
* In `Compose<FA, Compose<FB, FC>>`, only one of `FA`, `FB`, or `FC` can use
|
||||
* the Ds tensor.
|
||||
*
|
||||
* @tparam FuncA The first function to be applied.
|
||||
* @tparam FuncB The second function to be applied.
|
||||
* @tparam FuncADs Whether `FuncA` uses the Ds tensor.
|
||||
* @tparam FuncBDs Whether `FuncB` uses the Ds tensor.
|
||||
*/
|
||||
template <typename FuncA, typename FuncB, bool FuncADs = false, bool FuncBDs = false>
|
||||
struct Compose
|
||||
{
|
||||
static_assert(!(FuncADs && FuncBDs), "Only one composed function may use the Ds tensor.");
|
||||
|
||||
CK_TILE_HOST_DEVICE Compose(FuncA func_a_ = FuncA{}, FuncB func_b_ = FuncB{})
|
||||
: func_a(func_a_), func_b(func_b_)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename AIn, typename BOut, typename AOut = AIn, typename... ADs>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(BOut& y, const AIn& x, const ADs&... ds) const
|
||||
{
|
||||
AOut tmp;
|
||||
if constexpr(FuncADs)
|
||||
{
|
||||
func_a(tmp, x, ds...);
|
||||
func_b(y, tmp);
|
||||
}
|
||||
else if constexpr(FuncBDs)
|
||||
{
|
||||
func_a(tmp, x);
|
||||
func_b(y, tmp, ds...);
|
||||
}
|
||||
else
|
||||
{
|
||||
func_a(tmp, x);
|
||||
func_b(y, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
const FuncA func_a;
|
||||
const FuncB func_b;
|
||||
};
|
||||
|
||||
// support fast conversion of int8 to fp16
|
||||
#if 0
|
||||
template <typename InputDataType, typename OutputDataType, index_t RegPackNumber>
|
||||
|
||||
Reference in New Issue
Block a user