Merge commit '096f0a3b23a49ffaef1e2dbed74bf366e36ad15c' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-24 07:13:25 +00:00
parent 8abfd83364
commit 7bd01a9f5f
2 changed files with 36 additions and 27 deletions

View File

@@ -5,9 +5,11 @@
#include <iostream>
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "ck_tile/core/utility/env.hpp"
@@ -884,7 +886,8 @@ struct GroupedConvolutionForwardKernel
const CDescType& c_desc,
const index_t gemm_k,
const index_t block_idx_m,
const index_t block_idx_n)
const index_t block_idx_n,
const CDElementwise& elfunc)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
@@ -907,8 +910,9 @@ struct GroupedConvolutionForwardKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
/**
@@ -942,7 +946,8 @@ struct GroupedConvolutionForwardKernel
const CDescType& c_desc,
const index_t gemm_k,
const index_t block_idx_m,
const index_t block_idx_n)
const index_t block_idx_n,
const CDElementwise& elfunc)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
@@ -964,8 +969,9 @@ struct GroupedConvolutionForwardKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
@@ -998,6 +1004,14 @@ struct GroupedConvolutionForwardKernel
OutDataType* base_c_ptr =
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
// Apply group offsets to D tensors
std::array<const void*, NumDTensor> ds_ptr_with_offsets;
static_for<0, NumDTensor, 1>{}([&](auto d) {
using DType = std::tuple_element_t<d, DsDataType>;
ds_ptr_with_offsets[d] =
static_cast<const DType*>(kargs.ds_ptr[d]) + group_offset_c + output_batch_offset;
});
// =====================================================================
// Split-image: Map local block to global tile index (if enabled)
// =====================================================================
@@ -1085,7 +1099,7 @@ struct GroupedConvolutionForwardKernel
{
RunGemm2LDS(a_ptr,
b_ptr,
kargs.ds_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
smem_ptr_1,
@@ -1094,7 +1108,8 @@ struct GroupedConvolutionForwardKernel
c_desc,
kargs.GemmK,
i_m,
i_n);
i_n,
kargs.elfunc);
}
}
else
@@ -1105,7 +1120,7 @@ struct GroupedConvolutionForwardKernel
{
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
ds_ptr_with_offsets,
c_ptr,
smem_ptr_0,
a_desc,
@@ -1113,7 +1128,8 @@ struct GroupedConvolutionForwardKernel
c_desc,
kargs.GemmK,
i_m,
i_n);
i_n,
kargs.elfunc);
}
}
}