[CK Tile] Fix example for conv fwd + bias + clamp (#3235)

* Fix clamp not being applied correctly

* Apply group offsets to D tensors

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
Johannes Graner
2025-11-24 07:36:26 +01:00
committed by GitHub
parent f6c999bddb
commit 096f0a3b23
2 changed files with 36 additions and 27 deletions

View File

@@ -5,9 +5,11 @@
#include <iostream>
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/core/utility/env.hpp"
@@ -884,7 +886,8 @@ struct GroupedConvolutionForwardKernel
const CDescType& c_desc,
const index_t gemm_k,
const index_t block_idx_m,
const index_t block_idx_n)
const index_t block_idx_n,
const CDElementwise& elfunc)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
@@ -907,8 +910,9 @@ struct GroupedConvolutionForwardKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
/**
@@ -942,7 +946,8 @@ struct GroupedConvolutionForwardKernel
const CDescType& c_desc,
const index_t gemm_k,
const index_t block_idx_m,
const index_t block_idx_n)
const index_t block_idx_n,
const CDElementwise& elfunc)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
@@ -964,8 +969,9 @@ struct GroupedConvolutionForwardKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
@@ -998,6 +1004,14 @@ struct GroupedConvolutionForwardKernel
OutDataType* base_c_ptr =
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
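// Offset every D tensor pointer by the same group and batch offsets that are
// applied to the output (C) pointer above. Each pointer is first cast to its own
// element type so the offset advances in elements of that D tensor, not bytes.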
std::array<const void*, NumDTensor> ds_ptr_with_offsets;
static_for<0, NumDTensor, 1>{}([&](auto d) {
using DType = std::tuple_element_t<d, DsDataType>;
ds_ptr_with_offsets[d] =
static_cast<const DType*>(kargs.ds_ptr[d]) + group_offset_c + output_batch_offset;
});
// =====================================================================
// Split-image: Map local block to global tile index (if enabled)
// =====================================================================
@@ -1085,7 +1099,7 @@ struct GroupedConvolutionForwardKernel
{
RunGemm2LDS(a_ptr,
b_ptr,
kargs.ds_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
smem_ptr_1,
@@ -1094,7 +1108,8 @@ struct GroupedConvolutionForwardKernel
c_desc,
kargs.GemmK,
i_m,
i_n);
i_n,
kargs.elfunc);
}
}
else
@@ -1105,7 +1120,7 @@ struct GroupedConvolutionForwardKernel
{
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
a_desc,
@@ -1113,7 +1128,8 @@ struct GroupedConvolutionForwardKernel
c_desc,
kargs.GemmK,
i_m,
i_n);
i_n,
kargs.elfunc);
}
}
}