mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[CK_TILE] Add conv fwd + bias + clamp example (#3012)
* Implement argument passing to element-wise functions for fwd convolution * Add files for fwd + bias + clamp example * Implement Bias * Implement Clamp * Elementwise function composition * Composition unit test * Implement fwd + bias + clamp example * Simplify argument passing and composition * elfunc -> bias_and_clamp * Rename function to specify example * Move element-wise function instantiation to kernel * Make bias a runtime tensor * No ugly namespace aliasing * Initialize element-wise function on host * Remove function initialization helper, simplify Compose initialization * Remove unintended LSP compatibility patch * Clean up includes and unused code * Switch names in cshuffle epilogue * Move CDElementwise to conv traits * Re-add required include * Initialize bias in same way as other tensors * Better type specification for ds pointer * Disable 1D convolution * Add warning for non-group-constant bias
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -14,7 +15,7 @@ namespace ck_tile {
|
||||
/// This structure is passed to Grouped Convolution Kernels when creating kernel
|
||||
/// arguments object. It contain all necessary information required to
|
||||
/// build proper kernel argument and launch kernel on GPU.
|
||||
template <typename InPtr, typename WeiPtr, typename OutPtr>
|
||||
template <typename InPtr, typename WeiPtr, typename OutPtr, typename CDElementwise>
|
||||
struct GroupedConvHostArgs : public conv::ConvParam
|
||||
{
|
||||
CK_TILE_HOST GroupedConvHostArgs() = delete;
|
||||
@@ -23,13 +24,15 @@ struct GroupedConvHostArgs : public conv::ConvParam
|
||||
WeiPtr wei_ptr_,
|
||||
const std::vector<const void*> ds_ptr_,
|
||||
OutPtr out_ptr_,
|
||||
index_t k_batch_)
|
||||
index_t k_batch_,
|
||||
CDElementwise elfunc_ = CDElementwise{})
|
||||
: conv::ConvParam(conv_param),
|
||||
in_ptr(in_ptr_),
|
||||
wei_ptr(wei_ptr_),
|
||||
ds_ptr(ds_ptr_),
|
||||
out_ptr(out_ptr_),
|
||||
k_batch(k_batch_)
|
||||
k_batch(k_batch_),
|
||||
elfunc(elfunc_)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -38,11 +41,17 @@ struct GroupedConvHostArgs : public conv::ConvParam
|
||||
const std::vector<const void*> ds_ptr;
|
||||
OutPtr out_ptr;
|
||||
index_t k_batch;
|
||||
const CDElementwise elfunc;
|
||||
};
|
||||
|
||||
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*>;
|
||||
using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs<const void*, void*, const void*>;
|
||||
using GroupedConvBwdDataHostArgs = GroupedConvHostArgs<void*, const void*, const void*>;
|
||||
using PassThrough = ck_tile::element_wise::PassThrough;
|
||||
|
||||
template <typename CDElementwise = PassThrough>
|
||||
using GroupedConvFwdHostArgs = GroupedConvHostArgs<const void*, const void*, void*, CDElementwise>;
|
||||
using GroupedConvBwdWeightHostArgs =
|
||||
GroupedConvHostArgs<const void*, void*, const void*, PassThrough>;
|
||||
using GroupedConvBwdDataHostArgs =
|
||||
GroupedConvHostArgs<void*, const void*, const void*, PassThrough>;
|
||||
|
||||
template <index_t NDimSpatial_,
|
||||
ConvolutionSpecialization ConvSpecialization_,
|
||||
@@ -50,9 +59,10 @@ template <index_t NDimSpatial_,
|
||||
typename WeiLayout_,
|
||||
typename DsLayout_,
|
||||
typename OutLayout_,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1>
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1,
|
||||
index_t VectorSizeC_ = 1,
|
||||
typename CDElementwise_ = PassThrough>
|
||||
struct GroupedConvTraits
|
||||
{
|
||||
private:
|
||||
@@ -70,6 +80,7 @@ struct GroupedConvTraits
|
||||
using WeiLayout = WeiLayout_;
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
using CDElementwise = CDElementwise_;
|
||||
using GroupedConvImplicitGemmTraitsFwd =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
|
||||
Reference in New Issue
Block a user