mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
[rocm-libraries] ROCm/rocm-libraries#4797 (commit 1a30400)
[CK_TILE] Add CK Tile bwd weight profiler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation To compare old CK and CK Tile, we need to extend the current CK profiler to support running also CK Tile instance with the same API. In order to have the same instance coverage in CK Tile compared to the old CK, I've added code generation from old CK configurations to CK Tile instances using the CK Builder. ## Technical Details - The codegen python script for CK Tile fwd convs is extended to support also bwd weight and bwd data. - The generated instances are added to the CMake build (target `device_grouped_conv_bwd_weight_tile_instance`s). - A new profiler op (`grouped_conv_bwd_weight_tile`) has been added to the CK Profiler.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fc1e1a5155
commit
ae4e632c7d
@@ -55,10 +55,263 @@ static void print_helper_msg()
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void print_bwd_data_instances(ConvDataType data_type,
|
||||
ConvLayout layout,
|
||||
ck::index_t num_dim_spatial)
|
||||
{
|
||||
auto print_available_instances = [&](auto num_dim_spatial_tmp,
|
||||
auto in_layout,
|
||||
auto wei_layout,
|
||||
auto out_layout,
|
||||
auto in_type,
|
||||
auto wei_type,
|
||||
auto out_type,
|
||||
auto compute_type) {
|
||||
constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value;
|
||||
|
||||
using InLayout = decltype(in_layout);
|
||||
using WeiLayout = decltype(wei_layout);
|
||||
using OutLayout = decltype(out_layout);
|
||||
|
||||
using InDataType = decltype(in_type);
|
||||
using WeiDataType = decltype(wei_type);
|
||||
using OutDataType = decltype(out_type);
|
||||
|
||||
using ComputeType = decltype(compute_type);
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
ck::profiler::bwd_data::print_instances<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ComputeType>();
|
||||
};
|
||||
|
||||
constexpr auto I2 = ck::Number<2>{};
|
||||
constexpr auto I3 = ck::Number<3>{};
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using TF32 = ck::tf32_t;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
using namespace ck::profiler;
|
||||
|
||||
if(num_dim_spatial == 2)
|
||||
{
|
||||
if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, GNHWC{}, GKYXC{}, GNHWK{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, GNHWC{}, GKYXC{}, GNHWK{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3)
|
||||
{
|
||||
if(layout == ConvLayout::GNHWC_GKYXC_GNHWK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NGCDHW{}, GKZYXC{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
|
||||
{
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, F32{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F16{}, F16{}, F16{}, F16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32_TF32)
|
||||
{
|
||||
return print_available_instances(
|
||||
I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "[CK_PROFILER] This data_type & layout is not implemented" << std::endl;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
{
|
||||
if(argc == 6 && std::string(argv[5]) == "--instances")
|
||||
{
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
|
||||
const ck::index_t num_dim_spatial = static_cast<ck::index_t>(std::stoi(argv[4]));
|
||||
|
||||
print_bwd_data_instances(data_type, layout, num_dim_spatial);
|
||||
return 0;
|
||||
}
|
||||
// Parse optional named arguments first
|
||||
ck::index_t instance_index = -1;
|
||||
bool list_instances = false;
|
||||
|
||||
Reference in New Issue
Block a user