Merge commit '1c433c64ec5254d202b7cbf4b8b0e98678ea2a4f' into develop

This commit is contained in:
assistant-librarian[bot]
2026-01-06 09:16:30 +00:00
parent 2285a8345a
commit 5d0010c4b9
9 changed files with 349 additions and 60 deletions

View File

@@ -125,9 +125,9 @@ struct ReferenceFactory
// Direct Run method (simpler interface, direction-agnostic)
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
static void Run(InPtrType input,
WeiPtrType weight,
OutPtrType output,
static void Run(InPtrType* input,
WeiPtrType* weight,
OutPtrType* output,
int G,
int N,
int K,
@@ -142,9 +142,9 @@ struct ReferenceFactory
if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
ck_tile::naive_grouped_conv_fwd<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
input,
weight,
output,
static_cast<const InDataType*>(input),
static_cast<const WeiDataType*>(weight),
static_cast<OutDataType*>(output),
G,
N,
K,
@@ -160,9 +160,9 @@ struct ReferenceFactory
{
ck_tile::
naive_grouped_conv_bwd_data<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
input,
weight,
output,
static_cast<InDataType*>(input),
static_cast<const WeiDataType*>(weight),
static_cast<const OutDataType*>(output),
G,
N,
K,
@@ -179,19 +179,20 @@ struct ReferenceFactory
ck_tile::naive_grouped_conv_bwd_weight<SPATIAL_DIM,
InDataType,
WeiDataType,
OutDataType>(input,
weight,
output,
G,
N,
K,
C,
input_spatial,
filter_spatial,
output_spatial,
strides,
dilations,
left_pads);
OutDataType>(
static_cast<const InDataType*>(input),
static_cast<WeiDataType*>(weight),
static_cast<const OutDataType*>(output),
G,
N,
K,
C,
input_spatial,
filter_spatial,
output_spatial,
strides,
dilations,
left_pads);
}
}