mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK Tile] Fix example for conv fwd + bias + clamp (#3235)
* Fix clamp not being applied correctly
* Apply group offsets to D tensors

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -184,11 +184,6 @@ int run_grouped_conv_fwd_bias_clamp_example_with_layouts(
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
// FIXME: Address this issue
|
||||
if(arg_parser.get_int("g") > 1 && init_method == 0)
|
||||
std::cerr << "Adding different bias to different groups yield incorrect results"
|
||||
<< std::endl;
|
||||
|
||||
ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
|
||||
output_host_ref.SetZero();
|
||||
|
||||
@@ -250,29 +245,27 @@ template <typename Invoker,
|
||||
int run_grouped_conv_fwd_bias_clamp_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
// using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
// using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
// using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
// FIXME: Fix crash in 1D convolution when using Ds tensor.
|
||||
throw std::runtime_error("1D Convolution does not support bias.");
|
||||
// return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
|
||||
// GemmWarpConfig,
|
||||
// Invoker,
|
||||
// InPrecType,
|
||||
// WeiPrecType,
|
||||
// OutPrecType>(
|
||||
// argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
@@ -884,7 +886,8 @@ struct GroupedConvolutionForwardKernel
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
const index_t block_idx_n,
|
||||
const CDElementwise& elfunc)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
@@ -907,8 +910,9 @@ struct GroupedConvolutionForwardKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -942,7 +946,8 @@ struct GroupedConvolutionForwardKernel
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
const index_t block_idx_n,
|
||||
const CDElementwise& elfunc)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
@@ -964,8 +969,9 @@ struct GroupedConvolutionForwardKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
|
||||
@@ -998,6 +1004,14 @@ struct GroupedConvolutionForwardKernel
|
||||
OutDataType* base_c_ptr =
|
||||
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
|
||||
|
||||
// Apply group offsets to D tensors
|
||||
std::array<const void*, NumDTensor> ds_ptr_with_offsets;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto d) {
|
||||
using DType = std::tuple_element_t<d, DsDataType>;
|
||||
ds_ptr_with_offsets[d] =
|
||||
static_cast<const DType*>(kargs.ds_ptr[d]) + group_offset_c + output_batch_offset;
|
||||
});
|
||||
|
||||
// =====================================================================
|
||||
// Split-image: Map local block to global tile index (if enabled)
|
||||
// =====================================================================
|
||||
@@ -1085,7 +1099,7 @@ struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
ds_ptr_with_offsets,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
@@ -1094,7 +1108,8 @@ struct GroupedConvolutionForwardKernel
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
i_m,
|
||||
i_n);
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1105,7 +1120,7 @@ struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
ds_ptr_with_offsets,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
a_desc,
|
||||
@@ -1113,7 +1128,8 @@ struct GroupedConvolutionForwardKernel
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
i_m,
|
||||
i_n);
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user