mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
NHWC conv 2d: fwd bfp16/int8, Device level tuning and host API (#73)
* add fwd bf16 conv
* change tunning parametor
* add int8 for conv fwd
* remove comments
* change tunning parametor for int8
* change init int8 example
* add test for conv2d fwd
* change device operation file pos because merge develop
* fwd int8 use reference
* test_conv_fwd use reference
* add braket for if statement
* rename fwd example name
* remove StaticBufferOfVectorTypeV2
* tweak example
Co-authored-by: ltqin <letaoqin@amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 880fbee957]
This commit is contained in:
@@ -21,7 +21,7 @@ void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
|
||||
constexpr auto I1 = ck::Number<1>{};
|
||||
|
||||
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
double v = 0;
|
||||
float v = 0;
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
|
||||
@@ -33,13 +33,13 @@ void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
v += static_cast<const double>(in(n, c, hi, wi)) *
|
||||
static_cast<const double>(wei(k, c, y, x));
|
||||
v += ck::type_convert<float>(in(n, c, hi, wi)) *
|
||||
ck::type_convert<float>(wei(k, c, y, x));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out(n, k, ho, wo) = v;
|
||||
out(n, k, ho, wo) = ck::type_convert<TOut>(v);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw,
|
||||
|
||||
Reference in New Issue
Block a user