mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK Tile] Fix example for conv fwd + bias + clamp (#3235)
* Fix clamp not being applied correctly * Apply group offsets to D tensors --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -5,9 +5,11 @@
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/tensor/tile_elementwise.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
@@ -884,7 +886,8 @@ struct GroupedConvolutionForwardKernel
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
const index_t block_idx_n,
|
||||
const CDElementwise& elfunc)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
@@ -907,8 +910,9 @@ struct GroupedConvolutionForwardKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -942,7 +946,8 @@ struct GroupedConvolutionForwardKernel
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
const index_t block_idx_n,
|
||||
const CDElementwise& elfunc)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
@@ -964,8 +969,9 @@ struct GroupedConvolutionForwardKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
|
||||
@@ -998,6 +1004,14 @@ struct GroupedConvolutionForwardKernel
|
||||
OutDataType* base_c_ptr =
|
||||
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
|
||||
|
||||
// Apply group offsets to D tensors
|
||||
std::array<const void*, NumDTensor> ds_ptr_with_offsets;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto d) {
|
||||
using DType = std::tuple_element_t<d, DsDataType>;
|
||||
ds_ptr_with_offsets[d] =
|
||||
static_cast<const DType*>(kargs.ds_ptr[d]) + group_offset_c + output_batch_offset;
|
||||
});
|
||||
|
||||
// =====================================================================
|
||||
// Split-image: Map local block to global tile index (if enabled)
|
||||
// =====================================================================
|
||||
@@ -1085,7 +1099,7 @@ struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
ds_ptr_with_offsets,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
@@ -1094,7 +1108,8 @@ struct GroupedConvolutionForwardKernel
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
i_m,
|
||||
i_n);
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -1105,7 +1120,7 @@ struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
ds_ptr_with_offsets,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
a_desc,
|
||||
@@ -1113,7 +1128,8 @@ struct GroupedConvolutionForwardKernel
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
i_m,
|
||||
i_n);
|
||||
i_n,
|
||||
kargs.elfunc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user