diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp index eceb70c05f..523ed4dc20 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_data_impl.hpp @@ -25,6 +25,47 @@ namespace ck { namespace profiler { +namespace bwd_data +{ + template + void print_instances() + { + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + OutElementOp, + WeiElementOp, + InElementOp, + ComputeDataType, + ComputeDataType>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(const auto& op_ptr : op_ptrs) + { + std::cout << op_ptr->GetTypeString() << std::endl; + } + } +} + template + void print_instances() + { + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(const auto& op_ptr : op_ptrs) + { + std::cout << op_ptr->GetTypeString() << std::endl; + } + } +} + template + void print_instances() + { + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ComputeTypeA, + ComputeTypeB>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + for(const auto& op_ptr : op_ptrs) + { + std::cout << op_ptr->GetTypeString() << std::endl; + } + } +} + template (); +} + +void print_bwd_data_instances(auto data_type, auto layout, auto num_dim_spatial) +{ + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + using F32 = float; + using F16 = 
ck::half_t; + using BF16 = ck::bhalf_t; + using TF32 = ck::tf32_t; + + using namespace ck::tensor_layout::convolution; + using namespace ck::profiler; + + if(num_dim_spatial == 2) + { + if(layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}); + } + } + else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}); + } + } + else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I2, NGCHW{}, 
GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}); + } + } + else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}); + } + } + } + else if(num_dim_spatial == 3) + { + if(layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}); + } + } + else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}); + } + else 
if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}); + } + } + else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}); + } + } + else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + // GKCZYX weights: keyed on NGCHW_GKCYX_NGKHW, mirroring the 2D GKCYX branch; the GKYXC-weight variant is handled by the branch above. + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}); + } + } + } + + std::cout << "[CK_PROFILER] This data_type & layout is not implemented" << std::endl; +} + } // namespace int profile_grouped_conv_bwd_data(int argc, char* argv[]) { + if (argc == 6 && std::string(argv[5]) == "--instances") + { + const auto data_type = static_cast(std::stoi(argv[2])); + 
const auto layout = static_cast(std::stoi(argv[3])); + const ck::index_t num_dim_spatial = static_cast(std::stoi(argv[4])); + + print_bwd_data_instances(data_type, layout, num_dim_spatial); + return 0; + } + // 8 for control, 1 for num_dim_spatial if(argc < 9) { diff --git a/profiler/src/profile_grouped_conv_bwd_weight.cpp b/profiler/src/profile_grouped_conv_bwd_weight.cpp index c4f154e180..8ba75c28fb 100644 --- a/profiler/src/profile_grouped_conv_bwd_weight.cpp +++ b/profiler/src/profile_grouped_conv_bwd_weight.cpp @@ -64,10 +64,259 @@ static void print_helper_msg() << std::endl; } +void print_available_instances(auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto compute_type_a, + auto compute_type_b) +{ + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using ComputeTypeA = decltype(compute_type_a); + using ComputeTypeB = decltype(compute_type_b); + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + ck::profiler::bwd_weight::print_instances(); +} + +void print_bwd_weight_instances(auto data_type, auto layout, auto num_dim_spatial) +{ + constexpr auto I1 = ck::Number<1>{}; + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using F8 = ck::f8_t; + using BF8 = ck::bf8_t; + using TF32 = ck::tf32_t; + + using namespace ck::tensor_layout::convolution; + + if(num_dim_spatial == 1 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + if(data_type 
== ConvDataType::F16_F16_F16) + { + return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + if(data_type == ConvDataType::BF16_F32_BF16) + { + // fp32 atomic add is used for weight tensor in bf16 kernel + return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + if(data_type == ConvDataType::BF16_F32_BF16) + { + // fp32 atomic add is used for weight tensor in bf16 kernel + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + if(data_type == ConvDataType::BF16_F32_BF16) + { + // fp32 atomic add is used for weight tensor in bf16 kernel + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); + } + if(data_type == ConvDataType::BF16_BF16_BF16) + { + return 
print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) + { + if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + if(data_type == ConvDataType::BF16_BF16_BF16) + { + // fp32 atomic add is used for weight tensor in bf16 kernel + return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + if(data_type == ConvDataType::BF16_F32_BF16) + { + // fp32 atomic add is used for weight tensor in bf16 kernel + return 
print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::I8_I8_I8) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + if(data_type == ConvDataType::BF16_F32_BF16) + { + // fp32 atomic add is used for weight tensor in bf16 kernel + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{}); + } + if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + if(data_type == ConvDataType::F16_F16_F16_BF8_F8) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{}); + } + else if(data_type == ConvDataType::I8_I8_I8) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) + { + if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + if(data_type 
== ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + + std::cout << "[CK_PROFILER] This data_type & layout is not implemented." << std::endl; +} + } // namespace int profile_grouped_conv_bwd_weight(int argc, char* argv[]) { + if (argc == 6 && std::string(argv[5]) == "--instances") + { + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const ck::index_t num_dim_spatial = static_cast(std::stoi(argv[4])); + + print_bwd_weight_instances(data_type, layout, num_dim_spatial); + return 0; + } + // 8 for control, 1 for num_dim_spatial if(argc < 9) { diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 4319d849c8..d070be1b99 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -70,10 +70,293 @@ static void print_helper_msg() // clang-format on } +void print_available_instances(auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto compute_type_a, + auto compute_type_b) +{ + constexpr ck::index_t 
NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using ComputeTypeA = decltype(compute_type_a); + using ComputeTypeB = decltype(compute_type_b); + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + ck::profiler::fwd::print_instances(); +} + +void print_fwd_instances(auto data_type, auto layout, auto num_dim_spatial) +{ + constexpr auto I1 = ck::Number<1>{}; + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using F8 = ck::f8_t; + using BF8 = ck::bf8_t; + using TF32 = ck::tf32_t; + using INT8 = int8_t; + + using namespace ck::tensor_layout::convolution; + + // GNHWC_GKYXC_GNHWK + if(num_dim_spatial == 1 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return 
print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + // NHWGC_GKYXC_NHWGK + else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return 
print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I2, NGCHW{}, 
GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::INT8_INT8_INT8) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); + } + else if(data_type == ConvDataType::F8_F8_F8) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, F8{}, F8{}); + } + else if(data_type == ConvDataType::BF8_BF8_F8) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, 
F8{}, BF8{}, BF8{}); + } + else if(data_type == ConvDataType::F8_BF8_F8) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, F8{}, BF8{}); + } + else if(data_type == ConvDataType::BF8_F8_F8) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + // NGCDHW_GKCZYX_NGKDHW + else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { + return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); + } + } + + std::cout << "[CK_PROFILER] This data_type & layout is not implemented" << std::endl; +} + } // namespace int profile_grouped_conv_fwd(int argc, char* argv[]) { + if (argc == 6 && std::string(argv[5]) == "--instances") + { + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const ck::index_t num_dim_spatial = static_cast(std::stoi(argv[4])); + + print_fwd_instances(data_type, layout, num_dim_spatial); + return 0; + } + // 8 for control, 1 for num_dim_spatial if(argc < 10) {