From 96a4a5de376ebed481ffccde0a549a7abba55ed8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Mon, 22 Dec 2025 11:05:00 -0500 Subject: [PATCH] Factory bug fixes. --- .../builder/factory/conv_bwd_weight_xdl_factory.hpp | 9 +++++---- .../builder/factory/helpers/ck/conv_tensor_type.hpp | 2 ++ .../builder/factory/helpers/ck/conv_tuning_params.hpp | 1 + 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 0f726fe67d..db36114997 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -35,6 +35,7 @@ struct ConvBwdWeightXdlFactory static constexpr auto BLOCK = internal::SetThreadBlockInfo(); static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -70,10 +71,10 @@ struct ConvBwdWeightXdlFactory BLOCK.per_block.n, GRIDWISE_GEMM.k0_per_block, GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_xdl, - GRIDWISE_GEMM.n_per_xdl, - GRIDWISE_GEMM.m_xdl_per_wave, - GRIDWISE_GEMM.n_xdl_per_wave, + XDL_PARAMS.m_per_xdl, + XDL_PARAMS.n_per_xdl, + XDL_PARAMS.m_xdl_per_wave, + XDL_PARAMS.n_xdl_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp index d4a470dced..d6b0e06700 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp @@ -190,6 +190,8 @@ struct BwdWeightConvTensorDataTypes using InComputeType = typename decltype(input_types.second)::type; using WeiDataType = typename decltype(weight_types.first)::type; using WeiComputeType = typename decltype(weight_types.second)::type; + using OutDataType = typename decltype(output_types.first)::type; + using OutComputeType = typename decltype(output_types.second)::type; using AccDataType = typename decltype(GetTensorAccumulationType())::type; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index d7f3b17197..92a7b48ddd 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -169,6 +169,7 @@ consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0; case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0; case ConvSpecialization::ODD_C: return ck_conv_spec::OddC; + case ConvSpecialization::FILTER_3x3: throw "FILTER_3x3 is not supported for backward weight convolution."; default: throw "Unsupported ConvSpecialization"; } }