update kernel

This commit is contained in:
joye
2025-06-11 11:10:16 +08:00
parent b90beae6f0
commit 7bc604f06a

View File

@@ -1590,96 +1590,23 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
}
// if(filter_y == 3 && filter_x == 3)
// {
// if(stride_y == 1 && stride_x == 1 && pad_y == 1 && pad_x == 1)
// {
// // return kernel_grouped_conv_bwd_data_optimized<ADataType,
// // EDataType,
// // GroupPerBlock,
// // BatchPerBlock,
// // BlockDim,
// // 3,
// // 3,
// // 1,
// // 1,
// // 1,
// // 1>;
// return kernel_grouped_conv_bwd_data_optimized_v2<ADataType,
// EDataType,
// DIRECTION_BACKWARD,
// BlockDim,
// BatchPerBlock,
// GroupPerBlock,
// 4,
// 4,
// 3,
// 3,
// 1,
// 1,
// 1,
// 1,
// 1,
// 1>;
// }
// else if(stride_y == 2 && stride_x == 2 && pad_y == 1 && pad_x == 1)
// {
// return kernel_grouped_conv_bwd_data_optimized<ADataType,
// EDataType,
// GroupPerBlock,
// BatchPerBlock,
// BlockDim,
// 3,
// 3,
// 2,
// 2,
// 1,
// 1>;
// }
// }
// else if(filter_y == 5 && filter_x == 5)
// {
// if(stride_y == 1 && stride_x == 1 && pad_y == 2 && pad_x == 2)
// {
// return kernel_grouped_conv_bwd_data_optimized<ADataType,
// EDataType,
// GroupPerBlock,
// BatchPerBlock,
// BlockDim,
// 5,
// 5,
// 1,
// 1,
// 2,
// 2>;
// }
// else if(stride_y == 2 && stride_x == 2 && pad_y == 2 && pad_x == 2)
// {
// return kernel_grouped_conv_bwd_data_optimized<ADataType,
// EDataType,
// GroupPerBlock,
// BatchPerBlock,
// BlockDim,
// 5,
// 5,
// 2,
// 2,
// 2,
// 2>;
// }
// }
auto default_kernel = &kernel_grouped_conv_bwd_data_optimized<ADataType,
EDataType,
GroupPerBlock,
BatchPerBlock,
512,
5,
5,
1,
1,
2,
2>;
auto default_kernel =
&kernel_grouped_conv_bwd_data_optimized_v2<ADataType,
EDataType,
DIRECTION_BACKWARD,
BlockDim,
BatchPerBlock,
GroupPerBlock,
4,
4,
6,
6,
1,
1,
1,
1,
2,
2>;
return static_cast<decltype(default_kernel)>(nullptr);
};
const auto kernel = kernel_selector();
@@ -1705,58 +1632,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
arg.a_g_n_k_wos_lengths_[NDimSpatial + 2],
arg.e_g_n_c_wis_lengths_[NDimSpatial + 1],
arg.e_g_n_c_wis_lengths_[NDimSpatial + 2],
arg.b_g_k_c_xs_lengths_[NDimSpatial + 1],
arg.b_g_k_c_xs_lengths_[NDimSpatial + 2],
arg.a_g_n_k_wos_lengths_[0],
arg.a_g_n_k_wos_lengths_[1]);
// const auto kernel =
// kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle<
// GridwiseGemm,
// ADataType, // TODO: distiguish A/B datatype
// typename GridwiseGemm::DsGridPointer,
// EDataType,
// AElementwiseOp,
// BElementwiseOp,
// CDEElementwiseOp,
// DeviceOp::AGridDesc_AK0_M_AK1,
// DeviceOp::BGridDesc_BK0_N_BK1,
// DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
// DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
// Block2ETileMap,
// ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
// ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
// has_main_loop,
// ElementOp>;
// return launch_and_time_kernel(
// stream_config,
// kernel,
// dim3(gdx, gdy, gdz),
// dim3(BlockSize),
// 0,
// p_a_grid,
// p_b_grid,
// arg.p_ds_grid_,
// p_e_grid,
// arg.a_element_op_,
// arg.b_element_op_,
// arg.cde_element_op_,
// arg.a_grid_desc_ak0_m_ak1_container_[i],
// arg.b_grid_desc_bk0_n_bk1_container_[i],
// arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
// arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
// arg.block_2_etile_map_container_[i],
// arg.compute_ptr_offset_of_batch_,
// arg.compute_ptr_offset_of_n_,
// arg.k_batch_);
};
// if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, arg.k_batch_))
// {
// ave_time += launch_kernel(integral_constant<bool, true>{});
// }
// else
// {
ave_time += launch_kernel();
// }
}
return ave_time;