mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_BUILDER] Add bwd weight factories (#3509)
* Add placeholder test. * Initial conv bwd weight factory. * Conv builder test refactoring. * Add missing pieces to bwd weight factory. * Improve compile time erros message when no matching factory is found. * Use amcro to ensure automatic macthing between concepts are their string representations. * Improve compile time diagnostics. * Small improvements. * Improve missing member/wrong type compile-time errors. * Improve compile time diagnostics. * Concept bug fixes. * Remove debug assert. * Update algorithm signature diagnostics. * Factory bug fixes. * First functional version of bwd weight conv factory. * Refactor handing of GEMM-K batch template parameter in conv bwd weight factory. * Concept improvements. * Improve concept diagnostics. * Introduve a common size type for concepts. * Update compiletime diagnostics to use the size type. * Update conv specialization enum. * Fix fwd conv builder tests. * Fix smoke tests. * Separate bwd weigth and bwd data tests into separate targets. * Clean-up CK Tile builder tests. * Add bwd weight XDL CShuffle V3 factory. * Build conv bwd weigth v3 instances successfully. * Add instance traits for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3. * Test fix. * Add instance traits for bwd weight algorithms. * Add unit tests for instance strings. * Build new instance traits unit tests but exclude WMMA for now. * Added factory for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle. * Conv bwd weight DL factory. * Final implementation for bwd weight DL factory. * Add test for creating DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle instance. * Add factory for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle * Treat ref algorithm the same way as real algorithms in the dispatcher. * Refactor large tensor support and WMMA configuration. * Add factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffleV3. * Update Readme. * Fix WMMA bwd weight tests. * Added factory and tests for DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3. * Factory and tests for DeviceGroupedConvBwdWeight_Wmma_CShuffle. * Dispatching for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffle. * Add factory for DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 * Fix DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 factory and compute types for input and output tensor in bwd weigth convs. * Fix fwd factories after refactoring. * clang-format * Move compile-time diagnostics to a separate branch. * Fix ref algorithm dispatching. * Fix smoke tests. * clang-format * Fix factory for regular WMMA conv bwd weight. * Clarify builder Readme. * Remove obsolete test file. * Fix test after merge. * clang-format * Remove the C++26 extensions. * Unify conv elementwise ops and layout definitions for fwd and bwd directions. * Remove old layout and elementwise ops. * Unify handling of conv tensor types between fwd and bwd directions. * Unify block transfer for fwd and bwd directions. Rename ThreadSliceDim to ThreadClusterRank. * Make BlockTransferDescriptor concept parametrized. Introduce a common TileTransferParameters concept for conv algorithms. * clang-format --------- Co-authored-by: Ville Pietilä <>
This commit is contained in:
@@ -50,7 +50,7 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
#endif
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3(
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d(
|
||||
typename GridwiseGemm::Argument karg,
|
||||
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
|
||||
@@ -858,30 +858,32 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
{
|
||||
if(gemm_arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
const auto kernel =
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
const auto kernel =
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
@@ -897,30 +899,32 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
|
||||
{
|
||||
if(gemm_arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
const auto kernel =
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
const auto kernel =
|
||||
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
|
||||
GridwiseGemm,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,19 +52,20 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
kernel_batched_gemm_xdlops_bwd_weight_multiple_d(
|
||||
const FloatA* __restrict__ p_a_grid,
|
||||
const FloatB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
|
||||
if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
|
||||
@@ -568,7 +569,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
int max_occupancy = 0;
|
||||
hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_occupancy,
|
||||
kernel_batched_gemm_xdlops_bwd_weight<
|
||||
kernel_batched_gemm_xdlops_bwd_weight_multiple_d<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -841,7 +842,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
||||
};
|
||||
|
||||
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
|
||||
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight_multiple_d<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
|
||||
Reference in New Issue
Block a user