mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Profile resnet layout fixes (#3360)
This commit is contained in:
@@ -75,13 +75,13 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
|
||||
is_same<decltype(layout), ck::tensor_layout::convolution::KCYX>::value ||
|
||||
is_same<decltype(layout), ck::tensor_layout::convolution::NKHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz});
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout);
|
||||
}
|
||||
else if constexpr(is_same<decltype(layout), tensor_layout::convolution::NHWC>::value ||
|
||||
is_same<decltype(layout), tensor_layout::convolution::KYXC>::value ||
|
||||
is_same<decltype(layout), tensor_layout::convolution::NHWK>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_});
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -75,13 +75,13 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
|
||||
is_same<decltype(layout), ck::tensor_layout::convolution::KCYX>::value ||
|
||||
is_same<decltype(layout), ck::tensor_layout::convolution::NKHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz});
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, H * W, W, 1_uz}, layout);
|
||||
}
|
||||
else if constexpr(is_same<decltype(layout), tensor_layout::convolution::NHWC>::value ||
|
||||
is_same<decltype(layout), tensor_layout::convolution::KYXC>::value ||
|
||||
is_same<decltype(layout), tensor_layout::convolution::NHWK>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_});
|
||||
return HostTensorDescriptor({N_, C_, H, W}, {C_ * H * W, 1_uz, W * C_, C_}, layout);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user