mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 13:17:36 +00:00
Add option to print out the available instances from CK profiler.
This commit is contained in:
@@ -25,6 +25,47 @@
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
namespace bwd_data
|
||||
{
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp,
|
||||
typename ComputeDataType>
|
||||
void print_instances()
|
||||
{
|
||||
using DeviceOp =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
InLayout,
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
ck::Tuple<>,
|
||||
InDataType,
|
||||
OutElementOp,
|
||||
WeiElementOp,
|
||||
InElementOp,
|
||||
ComputeDataType,
|
||||
ComputeDataType>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
for(const auto& op_ptr : op_ptrs)
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename OutLayout,
|
||||
typename WeiLayout,
|
||||
|
||||
@@ -29,6 +29,45 @@
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
namespace bwd_weight
|
||||
{
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp,
|
||||
typename ComputeTypeA,
|
||||
typename ComputeTypeB>
|
||||
void print_instances()
|
||||
{
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
for(const auto& op_ptr : op_ptrs)
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
|
||||
@@ -30,6 +30,47 @@
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
namespace fwd
|
||||
{
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp,
|
||||
typename ComputeTypeA,
|
||||
typename ComputeTypeB>
|
||||
void print_instances()
|
||||
{
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
ck::Tuple<>,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
ck::Tuple<>,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
for(const auto& op_ptr : op_ptrs)
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
|
||||
@@ -52,10 +52,231 @@ static void print_helper_msg()
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void print_available_instances(auto num_dim_spatial_tmp,
|
||||
auto in_layout,
|
||||
auto wei_layout,
|
||||
auto out_layout,
|
||||
auto in_type,
|
||||
auto wei_type,
|
||||
auto out_type,
|
||||
auto compute_type)
|
||||
{
|
||||
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
|
||||
|
||||
using InLayout = decltype(in_layout);
|
||||
using WeiLayout = decltype(wei_layout);
|
||||
using OutLayout = decltype(out_layout);
|
||||
|
||||
using InDataType = decltype(in_type);
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
|
||||
using ComputeType = decltype(compute_type);
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
ck::profiler::bwd_data::print_instances<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ComputeType>();
|
||||
}
|
||||
|
||||
void print_bwd_data_instances(auto data_type, auto layout, auto num_dim_spatial)
|
||||
{
|
||||
constexpr auto I2 = ck::Number<2>{};
|
||||
constexpr auto I3 = ck::Number<3>{};
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
using namespace ck::profiler;
|
||||
|
||||
if(num_dim_spatial == 2)
|
||||
{
|
||||
if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3)
|
||||
{
|
||||
if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "[CK_PROFILER] This data_type & layout is not implemented" << std::endl;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
{
|
||||
if (argc == 6 && std::string(argv[5]) == "--instances")
|
||||
{
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
|
||||
const ck::index_t num_dim_spatial = static_cast<ck::index_t>(std::stoi(argv[4]));
|
||||
|
||||
print_bwd_data_instances(data_type, layout, num_dim_spatial);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial
|
||||
if(argc < 9)
|
||||
{
|
||||
|
||||
@@ -64,10 +64,259 @@ static void print_helper_msg()
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
void print_available_instances(auto num_dim_spatial_tmp,
|
||||
auto in_layout,
|
||||
auto wei_layout,
|
||||
auto out_layout,
|
||||
auto in_type,
|
||||
auto wei_type,
|
||||
auto out_type,
|
||||
auto compute_type_a,
|
||||
auto compute_type_b)
|
||||
{
|
||||
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
|
||||
|
||||
using InLayout = decltype(in_layout);
|
||||
using WeiLayout = decltype(wei_layout);
|
||||
using OutLayout = decltype(out_layout);
|
||||
|
||||
using InDataType = decltype(in_type);
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
|
||||
using ComputeTypeA = decltype(compute_type_a);
|
||||
using ComputeTypeB = decltype(compute_type_b);
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
ck::profiler::bwd_weight::print_instances<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>();
|
||||
}
|
||||
|
||||
void print_bwd_weight_instances(auto data_type, auto layout, auto num_dim_spatial)
|
||||
{
|
||||
constexpr auto I1 = ck::Number<1>{};
|
||||
constexpr auto I2 = ck::Number<2>{};
|
||||
constexpr auto I3 = ck::Number<3>{};
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
if(num_dim_spatial == 1 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::I8_I8_I8)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_F32_BF16)
|
||||
{
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
if(data_type == ConvDataType::F16_F16_F16_BF8_F8)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::I8_I8_I8)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "[CK_PROFILER] This data_type & layout is not implemented." << std::endl;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
{
|
||||
if (argc == 6 && std::string(argv[5]) == "--instances")
|
||||
{
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
|
||||
const ck::index_t num_dim_spatial = static_cast<ck::index_t>(std::stoi(argv[4]));
|
||||
|
||||
print_bwd_weight_instances(data_type, layout, num_dim_spatial);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial
|
||||
if(argc < 9)
|
||||
{
|
||||
|
||||
@@ -70,10 +70,293 @@ static void print_helper_msg()
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void print_available_instances(auto num_dim_spatial_tmp,
|
||||
auto in_layout,
|
||||
auto wei_layout,
|
||||
auto out_layout,
|
||||
auto in_type,
|
||||
auto wei_type,
|
||||
auto out_type,
|
||||
auto compute_type_a,
|
||||
auto compute_type_b)
|
||||
{
|
||||
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
|
||||
|
||||
using InLayout = decltype(in_layout);
|
||||
using WeiLayout = decltype(wei_layout);
|
||||
using OutLayout = decltype(out_layout);
|
||||
|
||||
using InDataType = decltype(in_type);
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
|
||||
using ComputeTypeA = decltype(compute_type_a);
|
||||
using ComputeTypeB = decltype(compute_type_b);
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
ck::profiler::fwd::print_instances<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>();
|
||||
}
|
||||
|
||||
void print_fwd_instances(auto data_type, auto layout, auto num_dim_spatial)
|
||||
{
|
||||
constexpr auto I1 = ck::Number<1>{};
|
||||
constexpr auto I2 = ck::Number<2>{};
|
||||
constexpr auto I3 = ck::Number<3>{};
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
using TF32 = ck::tf32_t;
|
||||
using INT8 = int8_t;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
// GNHWC_GKYXC_GNHWK
|
||||
if(num_dim_spatial == 1 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
// NHWGC_GKYXC_NHWGK
|
||||
else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F8_F8_F8)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}, F8{}, F8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF8_BF8_F8)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{}, BF8{}, BF8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F8_BF8_F8)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, BF8{}, F8{}, F8{}, BF8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF8_F8_F8)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
// NGCDHW_GKCZYX_NGKDHW
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "[CK_PROFILER] This data_type & layout is not implemented" << std::endl;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
if (argc == 6 && std::string(argv[5]) == "--instances")
|
||||
{
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
|
||||
const ck::index_t num_dim_spatial = static_cast<ck::index_t>(std::stoi(argv[4]));
|
||||
|
||||
print_fwd_instances(data_type, layout, num_dim_spatial);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial
|
||||
if(argc < 10)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user