Add bfp16/int8 support into XDL GEMM operator (#50)

* init StaticBufferV2

* clean

* adopt old output stage for staticBufferV2

* clean

* remove hack

* clean

* clean

* add parameters

* clean code

* move c_buffer alloc into blockwise gemm

* add adaptors for m/n_thread_data_on_grid

* tweak gemm

* adjust blockwise_gemm_xdlops

* tweak

* update conv

* update script

* adding bwd 1x1

* update script

* adding 1x1 bwd

* debugging bwd 1x1 failure

* update script

* update script

* test

* test v100

* add bf16_1k

* clang-format

* clean

* add bfp16 for gfx908

* add verification

* clean up

* clean code

* restore bfl16

* clean

* add bfp16 support into gemm_driver

* apply new generator to other drivers

* add int8 support

* cleanb

* clean

* clean

* clean

Co-authored-by: Chao Liu <chao.liu2@amd.com>
Co-authored-by: Chao Liu <lc.roy86@gmail.com>
Co-authored-by: root <root@hayabusa6111.amd.com>

[ROCm/composable_kernel commit: 3737bb039a]
This commit is contained in:
zjing14
2021-11-15 10:24:39 -06:00
committed by GitHub
parent 8791d26e52
commit 456f5306df
11 changed files with 668 additions and 332 deletions

View File

@@ -325,30 +325,30 @@ int main(int argc, char* argv[])
// no initialization
break;
case 1:
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
break;
case 2:
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
break;
case 3:
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
break;
case 4:
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
break;
case 5:
out.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_3<out_data_t>{0.0, 1.0}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.5, 0.5}, num_thread);
break;
default:
out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{1, 5}, num_thread);
auto gen_wei = [](auto... is) {
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
return GeneratorTensor_2<in_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
};
wei.GenerateTensorValue(gen_wei, num_thread);
}

View File

@@ -80,13 +80,29 @@ void host_convolution_forward(const Tensor<TIn>& in,
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
wi < in.mDesc.GetLengths()[3])
{
v += static_cast<const double>(in(n, c, hi, wi)) *
static_cast<const double>(wei(k, c, y, x));
if constexpr(is_same<TIn, ushort>::value)
{
v += bfloat16_to_float(in(n, c, hi, wi)) *
bfloat16_to_float(wei(k, c, y, x));
}
else
{
v += static_cast<const double>(in(n, c, hi, wi)) *
static_cast<const double>(wei(k, c, y, x));
}
}
}
}
}
out(n, k, ho, wo) = v;
if constexpr(is_same<TOut, ushort>::value)
{
out(n, k, ho, wo) = float_to_bfloat16(v);
}
else
{
out(n, k, ho, wo) = v;
}
};
auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
@@ -102,13 +118,28 @@ void host_convolution_forward(const Tensor<TIn>& in,
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
wi < in.mDesc.GetLengths()[2])
{
v += static_cast<const double>(in(n, hi, wi, c)) *
static_cast<const double>(wei(k, y, x, c));
if constexpr(is_same<TIn, ushort>::value)
{
v += bfloat16_to_float(in(n, hi, wi, c)) *
bfloat16_to_float(wei(k, y, x, c));
}
else
{
v += static_cast<const double>(in(n, hi, wi, c)) *
static_cast<const double>(wei(k, y, x, c));
}
}
}
}
}
out(n, ho, wo, k) = v;
if constexpr(is_same<TOut, ushort>::value)
{
out(n, ho, wo, k) = float_to_bfloat16(v);
}
else
{
out(n, ho, wo, k) = v;
}
};
if(layout == ConvTensorLayout::NCHW)
@@ -226,10 +257,14 @@ int main(int argc, char* argv[])
using in_data_t = float;
using acc_data_t = float;
using out_data_t = float;
#elif 1
#elif 0
using in_data_t = half_t;
using acc_data_t = float;
using out_data_t = half_t;
#elif 1
using in_data_t = ushort;
using acc_data_t = float;
using out_data_t = ushort;
#elif 1
using in_data_t = int8_t;
using acc_data_t = int32_t;
@@ -295,30 +330,30 @@ int main(int argc, char* argv[])
// no initialization
break;
case 1:
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
break;
case 2:
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
break;
case 3:
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
break;
case 4:
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
break;
case 5:
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
in.GenerateTensorValue(GeneratorTensor_3<in_data_t>{0.0, 1.0}, num_thread);
wei.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.5, 0.5}, num_thread);
break;
default:
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{1, 5}, num_thread);
auto gen_wei = [](auto... is) {
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
return GeneratorTensor_2<in_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
};
wei.GenerateTensorValue(gen_wei, num_thread);
}

View File

@@ -297,30 +297,30 @@ int main(int argc, char* argv[])
// no initialization
break;
case 1:
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
break;
case 2:
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
break;
case 3:
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
break;
case 4:
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
break;
case 5:
in.GenerateTensorValue(GeneratorTensor_3<float>{-0.1, 0.1}, num_thread);
out.GenerateTensorValue(GeneratorTensor_3<float>{-0.1, 0.1}, num_thread);
in.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.1, 0.1}, num_thread);
out.GenerateTensorValue(GeneratorTensor_3<out_data_t>{-0.1, 0.1}, num_thread);
break;
default:
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{1, 5}, num_thread);
auto gen_out = [](auto... is) {
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
return GeneratorTensor_2<out_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
};
out.GenerateTensorValue(gen_out, num_thread);
}

View File

@@ -239,10 +239,14 @@ int main(int argc, char* argv[])
using ab_data_t = float;
using acc_data_t = float;
using c_data_t = float;
#elif 1
#elif 0
using ab_data_t = half_t;
using acc_data_t = float;
using c_data_t = half_t;
#elif 1
using ab_data_t = ushort;
using acc_data_t = float;
using c_data_t = ushort;
#elif 1
using ab_data_t = int8_t;
using acc_data_t = int32_t;
@@ -321,24 +325,24 @@ int main(int argc, char* argv[])
// no initialization
break;
case 1:
a.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
b.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
a.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
b.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
break;
case 2:
a.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
a.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
b.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
break;
case 3:
a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
b.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
a.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
b.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
break;
case 4:
a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
a.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
b.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
break;
default:
a.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
b.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
a.GenerateTensorValue(GeneratorTensor_3<ab_data_t>{0.0, 1.0}, num_thread);
b.GenerateTensorValue(GeneratorTensor_3<ab_data_t>{-0.5, 0.5}, num_thread);
}
#if USE_GEMM_XDL_MK_KN_MN