Add option to print out the available instances from CK profiler.

This commit is contained in:
Ville Pietilä
2026-02-03 08:51:19 -05:00
parent d54eb1d350
commit 1b30d45946
6 changed files with 874 additions and 0 deletions

View File

@@ -25,6 +25,47 @@
namespace ck {
namespace profiler {
namespace bwd_data
{
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename ComputeDataType>
void print_instances()
{
using DeviceOp =
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
OutLayout,
WeiLayout,
ck::Tuple<>,
InLayout,
OutDataType,
WeiDataType,
ck::Tuple<>,
InDataType,
OutElementOp,
WeiElementOp,
InElementOp,
ComputeDataType,
ComputeDataType>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(const auto& op_ptr : op_ptrs)
{
std::cout << op_ptr->GetTypeString() << std::endl;
}
}
}
template <ck::index_t NDimSpatial,
typename OutLayout,
typename WeiLayout,

View File

@@ -29,6 +29,45 @@
namespace ck {
namespace profiler {
namespace bwd_weight
{
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename ComputeTypeA,
typename ComputeTypeB>
void print_instances()
{
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ComputeTypeA,
ComputeTypeB>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(const auto& op_ptr : op_ptrs)
{
std::cout << op_ptr->GetTypeString() << std::endl;
}
}
}
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,

View File

@@ -30,6 +30,47 @@
namespace ck {
namespace profiler {
namespace fwd
{
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename ComputeTypeA,
typename ComputeTypeB>
void print_instances()
{
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ComputeTypeA,
ComputeTypeB>;
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
for(const auto& op_ptr : op_ptrs)
{
std::cout << op_ptr->GetTypeString() << std::endl;
}
}
}
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,

View File

@@ -52,10 +52,231 @@ static void print_helper_msg()
// clang-format on
}
void print_available_instances(auto num_dim_spatial_tmp,
auto in_layout,
auto wei_layout,
auto out_layout,
auto in_type,
auto wei_type,
auto out_type,
auto compute_type)
{
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
using InLayout = decltype(in_layout);
using WeiLayout = decltype(wei_layout);
using OutLayout = decltype(out_layout);
using InDataType = decltype(in_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using ComputeType = decltype(compute_type);
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
ck::profiler::bwd_data::print_instances<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
PassThrough,
PassThrough,
PassThrough,
ComputeType>();
}
void print_bwd_data_instances(auto data_type, auto layout, auto num_dim_spatial)
{
constexpr auto I2 = ck::Number<2>{};
constexpr auto I3 = ck::Number<3>{};
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using TF32 = ck::tf32_t;
using namespace ck::tensor_layout::convolution;
using namespace ck::profiler;
if(num_dim_spatial == 2)
{
if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{});
}
}
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{});
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{});
}
}
else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{});
}
}
}
else if(num_dim_spatial == 3)
{
if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{});
}
}
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{});
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{});
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{});
}
}
}
std::cout << "[CK_PROFILER] This data_type & layout is not implemented" << std::endl;
}
} // namespace
int profile_grouped_conv_bwd_data(int argc, char* argv[])
{
if (argc == 6 && std::string(argv[5]) == "--instances")
{
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
const ck::index_t num_dim_spatial = static_cast<ck::index_t>(std::stoi(argv[4]));
print_bwd_data_instances(data_type, layout, num_dim_spatial);
return 0;
}
// 8 for control, 1 for num_dim_spatial
if(argc < 9)
{

View File

@@ -64,10 +64,259 @@ static void print_helper_msg()
<< std::endl;
}
void print_available_instances(auto num_dim_spatial_tmp,
auto in_layout,
auto wei_layout,
auto out_layout,
auto in_type,
auto wei_type,
auto out_type,
auto compute_type_a,
auto compute_type_b)
{
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
using InLayout = decltype(in_layout);
using WeiLayout = decltype(wei_layout);
using OutLayout = decltype(out_layout);
using InDataType = decltype(in_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using ComputeTypeA = decltype(compute_type_a);
using ComputeTypeB = decltype(compute_type_b);
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
ck::profiler::bwd_weight::print_instances<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
PassThrough,
PassThrough,
PassThrough,
ComputeTypeA,
ComputeTypeB>();
}
void print_bwd_weight_instances(auto data_type, auto layout, auto num_dim_spatial)
{
constexpr auto I1 = ck::Number<1>{};
constexpr auto I2 = ck::Number<2>{};
constexpr auto I3 = ck::Number<3>{};
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using TF32 = ck::tf32_t;
using namespace ck::tensor_layout::convolution;
if(num_dim_spatial == 1 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
{
if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
if(data_type == ConvDataType::BF16_BF16_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
{
if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::I8_I8_I8)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
if(data_type == ConvDataType::BF16_F32_BF16)
{
// fp32 atomic add is used for weight tensor in bf16 kernel
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
if(data_type == ConvDataType::F16_F16_F16_BF8_F8)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
}
else if(data_type == ConvDataType::I8_I8_I8)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
{
if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
{
if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
std::cout << "[CK_PROFILER] This data_type & layout is not implemented." << std::endl;
}
} // namespace
int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{
if (argc == 6 && std::string(argv[5]) == "--instances")
{
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
const ck::index_t num_dim_spatial = static_cast<ck::index_t>(std::stoi(argv[4]));
print_bwd_weight_instances(data_type, layout, num_dim_spatial);
return 0;
}
// 8 for control, 1 for num_dim_spatial
if(argc < 9)
{

View File

@@ -70,10 +70,293 @@ static void print_helper_msg()
// clang-format on
}
void print_available_instances(auto num_dim_spatial_tmp,
auto in_layout,
auto wei_layout,
auto out_layout,
auto in_type,
auto wei_type,
auto out_type,
auto compute_type_a,
auto compute_type_b)
{
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
using InLayout = decltype(in_layout);
using WeiLayout = decltype(wei_layout);
using OutLayout = decltype(out_layout);
using InDataType = decltype(in_type);
using WeiDataType = decltype(wei_type);
using OutDataType = decltype(out_type);
using ComputeTypeA = decltype(compute_type_a);
using ComputeTypeB = decltype(compute_type_b);
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
ck::profiler::fwd::print_instances<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
PassThrough,
PassThrough,
PassThrough,
ComputeTypeA,
ComputeTypeB>();
}
void print_fwd_instances(auto data_type, auto layout, auto num_dim_spatial)
{
constexpr auto I1 = ck::Number<1>{};
constexpr auto I2 = ck::Number<2>{};
constexpr auto I3 = ck::Number<3>{};
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using TF32 = ck::tf32_t;
using INT8 = int8_t;
using namespace ck::tensor_layout::convolution;
// GNHWC_GKYXC_GNHWK
if(num_dim_spatial == 1 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
// NHWGC_GKYXC_NHWGK
else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::INT8_INT8_INT8)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
}
else if(data_type == ConvDataType::F8_F8_F8)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, F8{}, F8{});
}
else if(data_type == ConvDataType::BF8_BF8_F8)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{}, BF8{}, BF8{});
}
else if(data_type == ConvDataType::F8_BF8_F8)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, F8{}, BF8{});
}
else if(data_type == ConvDataType::BF8_F8_F8)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
// NGCDHW_GKCZYX_NGKDHW
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
}
}
std::cout << "[CK_PROFILER] This data_type & layout is not implemented" << std::endl;
}
} // namespace
int profile_grouped_conv_fwd(int argc, char* argv[])
{
if (argc == 6 && std::string(argv[5]) == "--instances")
{
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
const ck::index_t num_dim_spatial = static_cast<ck::index_t>(std::stoi(argv[4]));
print_fwd_instances(data_type, layout, num_dim_spatial);
return 0;
}
// 8 for control, 1 for num_dim_spatial
if(argc < 10)
{