mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
[CK_TILE] Add conv fwd + bias + clamp example (#3012)
* Implement argument passing to element-wise functions for fwd convolution
* Add files for fwd + bias + clamp example
* Implement Bias
* Implement Clamp
* Elementwise function composition
* Composition unit test
* Implement fwd + bias + clamp example
* Simplify argument passing and composition
* elfunc -> bias_and_clamp
* Rename function to specify example
* Move element-wise function instantiation to kernel
* Make bias a runtime tensor
* No ugly namespace aliasing
* Initialize element-wise function on host
* Remove function initialization helper, simplify Compose initialization
* Remove unintended LSP compatibility patch
* Clean up includes and unused code
* Switch names in cshuffle epilogue
* Move CDElementwise to conv traits
* Re-add required include
* Initialize bias in same way as other tensors
* Better type specification for ds pointer
* Disable 1D convolution
* Add warning for non-group-constant bias
[ROCm/composable_kernel commit: 5c1974065e]
This commit is contained in:
@@ -33,6 +33,14 @@ struct elementwise_op_traits<ck_tile::element_wise::Relu>
|
||||
static constexpr int num_inputs = 1;
|
||||
};
|
||||
|
||||
using NegRelu =
|
||||
ck_tile::element_wise::Compose<ck_tile::element_wise::Relu, ck_tile::element_wise::Neg>;
|
||||
template <>
|
||||
struct elementwise_op_traits<NegRelu>
|
||||
{
|
||||
static constexpr int num_inputs = 1;
|
||||
};
|
||||
|
||||
template <std::size_t D, typename F>
|
||||
auto make_uniform_array_with_factory(F&& factory)
|
||||
{
|
||||
@@ -194,7 +202,11 @@ using TestConfig_F16_Add = std::tuple<ck_tile::half_t,
|
||||
Shape1_BlockTile,
|
||||
Shape1_WarpTile>;
|
||||
|
||||
using TestTypes = ::testing::Types<TestConfig_F32_Add, TestConfig_F32_Relu, TestConfig_F16_Add>;
|
||||
using TestConfig_F32_Neg_Relu =
|
||||
std::tuple<float, float, float, NegRelu, Shape1_BlockWarps, Shape1_BlockTile, Shape1_WarpTile>;
|
||||
|
||||
using TestTypes = ::testing::
|
||||
Types<TestConfig_F32_Add, TestConfig_F32_Relu, TestConfig_F16_Add, TestConfig_F32_Neg_Relu>;
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileElementwise, TestTypes);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user