Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-06-30 11:47:48 +00:00)
Factory bug fixes.
This commit is contained in:
@@ -35,6 +35,7 @@ struct ConvBwdWeightXdlFactory
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -70,10 +71,10 @@ struct ConvBwdWeightXdlFactory
 BLOCK.per_block.n,
 GRIDWISE_GEMM.k0_per_block,
 GRIDWISE_GEMM.k1,
-GRIDWISE_GEMM.m_per_xdl,
-GRIDWISE_GEMM.n_per_xdl,
-GRIDWISE_GEMM.m_xdl_per_wave,
-GRIDWISE_GEMM.n_xdl_per_wave,
+XDL_PARAMS.m_per_xdl,
+XDL_PARAMS.n_per_xdl,
+XDL_PARAMS.m_xdl_per_wave,
+XDL_PARAMS.n_xdl_per_wave,
 to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
 to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
 to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -190,6 +190,8 @@ struct BwdWeightConvTensorDataTypes
|
||||
using InComputeType = typename decltype(input_types.second)::type;
|
||||
using WeiDataType = typename decltype(weight_types.first)::type;
|
||||
using WeiComputeType = typename decltype(weight_types.second)::type;
|
||||
using OutDataType = typename decltype(output_types.first)::type;
|
||||
using OutComputeType = typename decltype(output_types.second)::type;
|
||||
using AccDataType =
|
||||
typename decltype(GetTensorAccumulationType<Signature.accumulation_data_type,
|
||||
Signature.data_type>())::type;
|
||||
|
||||
@@ -169,6 +169,7 @@ consteval ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization
|
||||
case ConvSpecialization::FILTER_1X1_PAD0: return ck_conv_spec::Filter1x1Pad0;
|
||||
case ConvSpecialization::FILTER_1X1_STRIDE1_PAD0: return ck_conv_spec::Filter1x1Stride1Pad0;
|
||||
case ConvSpecialization::ODD_C: return ck_conv_spec::OddC;
|
||||
case ConvSpecialization::FILTER_3x3: throw "FILTER_3x3 is not supported for backward weight convolution.";
|
||||
default: throw "Unsupported ConvSpecialization";
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user