mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
fix Issue 291 (#294)
* rename for typeconvert functor
* refine code
[ROCm/composable_kernel commit: 4634b12043]
This commit is contained in:
@@ -433,7 +433,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
using namespace ck;
|
||||
|
||||
const index_t Di = input_spatial_lengths[0];
|
||||
const index_t Hi = input_spatial_lengths[2];
|
||||
const index_t Hi = input_spatial_lengths[1];
|
||||
const index_t Wi = input_spatial_lengths[2];
|
||||
|
||||
const index_t Do = output_spatial_lengths[0];
|
||||
@@ -671,11 +671,14 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
return PadDescriptor_M0_1d(desc, gridSize, blockSize);
|
||||
}
|
||||
|
||||
using TypeConvertFunctor =
|
||||
using TypeConvertFp32ToBf16Functor =
|
||||
ck::tensor_operation::element_wise::UnaryTypeConvert<ck::bhalf_t, float>;
|
||||
using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));
|
||||
using GridwiseUEltwise =
|
||||
GridwiseUnaryElementwise_1D<AccDataType, InDataType, GridDesc_M0, TypeConvertFunctor, 4>;
|
||||
using GridDesc_M0 = decltype(MakeDescriptor_M0<1>({1}, {1}, 1, 1));
|
||||
using GridwiseUEltwise = GridwiseUnaryElementwise_1D<AccDataType,
|
||||
InDataType,
|
||||
GridDesc_M0,
|
||||
TypeConvertFp32ToBf16Functor,
|
||||
4>;
|
||||
|
||||
using ABCGridDescs = decltype(GetABCGridDesc<NumDimSpatial>());
|
||||
|
||||
@@ -979,33 +982,32 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
|
||||
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
const auto run_conv = [&](const auto& kernel) {
|
||||
hipGetErrorString(hipMemset(
|
||||
arg.p_c_grid_,
|
||||
0,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(CDataType)));
|
||||
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
// run kernel for bf16 with splitk
|
||||
@@ -1016,22 +1018,21 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
|
||||
sizeof(AccDataType)));
|
||||
|
||||
ave_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
static_cast<AccDataType*>(arg.p_workspace_),
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
static_cast<AccDataType*>(arg.p_workspace_),
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
};
|
||||
|
||||
// kernel for type conversion
|
||||
@@ -1059,7 +1060,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
// run kernel for type conversion
|
||||
void* p_c_grid_tmp_ = static_cast<void*>(arg.p_c_grid_);
|
||||
InDataType* p_c_grid_tmp_bf16_ = static_cast<InDataType*>(p_c_grid_tmp_);
|
||||
const auto Run_type_convert = [&](const auto& kernel) {
|
||||
const auto run_type_convert = [&](const auto& kernel) {
|
||||
float elapsed_time =
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
@@ -1070,14 +1071,15 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
p_c_grid_tmp_bf16_,
|
||||
a_grid_desc_m0_,
|
||||
b_grid_desc_m0_,
|
||||
TypeConvertFunctor{});
|
||||
TypeConvertFp32ToBf16Functor{});
|
||||
return elapsed_time;
|
||||
};
|
||||
|
||||
if constexpr(std::is_same<InDataType, ck::bhalf_t>::value)
|
||||
{
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
@@ -1092,9 +1094,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
true>;
|
||||
has_main_loop>;
|
||||
|
||||
Run(kernel);
|
||||
return run_conv(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1103,7 +1105,7 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
AccDataType,
|
||||
InDataType,
|
||||
GridDesc_M0,
|
||||
TypeConvertFunctor>;
|
||||
TypeConvertFp32ToBf16Functor>;
|
||||
|
||||
const auto kernel_conv = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemmAtomicAddFloatBf16Splitk,
|
||||
@@ -1117,56 +1119,28 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
true>;
|
||||
has_main_loop>;
|
||||
|
||||
run_bf16_splitk(kernel_conv);
|
||||
ave_time += Run_type_convert(kernel_type_convert);
|
||||
float elapsed_time = 0;
|
||||
elapsed_time += run_bf16_splitk(kernel_conv);
|
||||
elapsed_time += run_type_convert(kernel_type_convert);
|
||||
return elapsed_time;
|
||||
}
|
||||
};
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
ave_time = launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemmAtomicAddFloatBf16Splitk,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
AccDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
run_bf16_splitk(kernel);
|
||||
}
|
||||
ave_time = launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
@@ -1181,9 +1155,9 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
true>;
|
||||
has_main_loop>;
|
||||
|
||||
Run(kernel);
|
||||
return run_conv(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1199,49 +1173,18 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
true>;
|
||||
has_main_loop>;
|
||||
|
||||
Run(kernel);
|
||||
return run_conv(kernel);
|
||||
}
|
||||
};
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
ave_time = launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
if(kbatch == 1)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemmAtomicAdd,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
Run(kernel);
|
||||
}
|
||||
ave_time = launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user