[POST MERGE PR] Enable grouped conv bwd wei bf16 NGCHW (#1594)

This commit is contained in:
Bartłomiej Kocot
2024-10-23 12:02:33 +02:00
committed by GitHub
parent 4d5248e2d1
commit cedccd59c9
8 changed files with 161 additions and 5 deletions

View File

@@ -182,6 +182,10 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
{
@@ -210,11 +214,6 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
// fp32 atomic add is used for weight tensor in bf16 kernel
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
}
if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
else if(data_type == ConvDataType::I8_I8_I8)
{
return profile(