Factory bug fixes.

This commit is contained in:
Ville Pietilä
2025-12-22 11:05:00 -05:00
parent a8e7edd814
commit 96a4a5de37
3 changed files with 8 additions and 4 deletions

View File

@@ -35,6 +35,7 @@ struct ConvBwdWeightXdlFactory
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -70,10 +71,10 @@ struct ConvBwdWeightXdlFactory
BLOCK.per_block.n,
GRIDWISE_GEMM.k0_per_block,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -190,6 +190,8 @@ struct BwdWeightConvTensorDataTypes
using InComputeType = typename decltype(input_types.second)::type;
using WeiDataType = typename decltype(weight_types.first)::type;
using WeiComputeType = typename decltype(weight_types.second)::type;
using OutDataType = typename decltype(output_types.first)::type;
using OutComputeType = typename decltype(output_types.second)::type;
using AccDataType =
typename decltype(GetTensorAccumulationType<Signature.accumulation_data_type,
Signature.data_type>())::type;

View File

@@ -169,6 +169,7 @@ consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization
case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
case ConvSpecialization::ODD_C: return ck_conv_spec::OddC;
case ConvSpecialization::FILTER_3x3: throw "FILTER_3x3 is not supported for backward weight convolution.";
default: throw "Unsupported ConvSpecialization";
}
}