[Ck_tile] smoothquant (#1617)

* fix compile error

* fix typo of padding

* Add smoothquant op

* Add smoothquant instance library

* refine type

* add test script

* Re-generate smoothquant.hpp

* Always use 'current year' in copyright

* use Generic2dBlockShape instead

* Add vector = 8 instance back

* Find exe path automatically

* Simplify the api condition

* Remove debugging code

* update year

* Add blank line between function declaration

* explicitly cast return value to dim3

* refine return value

* Fix default warmup and repeat value

* Add comment

* refactor sommthquant cmake

* Add README

* Fix typo

---------

Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
rocking
2024-11-01 13:51:56 +08:00
committed by GitHub
parent 550248deec
commit fbd654545a
62 changed files with 1758 additions and 219 deletions

View File

@@ -18,7 +18,7 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::half_t>
using BDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using XDataType = ck_tile::half_t;
using YScaleDataType = ck_tile::half_t;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
@@ -30,7 +30,7 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t>
using BDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t;
using XDataType = ck_tile::bf16_t;
using YScaleDataType = ck_tile::bf16_t;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
@@ -101,7 +101,7 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::AddRmsnorm2dRdquantShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveX = kSaveX_;

View File

@@ -66,7 +66,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using BDataType = DataType;
using GammaDataType = DataType;
using XDataType = DataType;
using YScaleDataType = DataType;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
@@ -99,12 +99,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
constexpr bool kThreePass = true;
using BlockWarps = ck_tile::sequence<2, 2>;
using BlockTile = ck_tile::sequence<2, 128>;
using BlockWarps = ck_tile::sequence<4, 1>;
using BlockTile = ck_tile::sequence<4, 128>;
using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::AddRmsnorm2dRdquantShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<ADataType,
BDataType,
GammaDataType,

View File

@@ -28,7 +28,6 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
add_rmsnorm2d_rdquant_fwd_args a,
const ck_tile::stream_config& s)
{
#if 1
float r = -1;
// clang-format off
// rm rn tm tn vn pd x 3p
@@ -128,9 +127,6 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, true, true>>(s, a);
}
return r;
#else
return add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, true, false>>(s, a);
#endif
// clang-format on
}
@@ -139,7 +135,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
const ck_tile::stream_config& s)
{
float r = -1;
// Only support instance of save_x == true for now
assert(t.save_x);
if(t.data_type.compare("fp16") == 0)
@@ -150,8 +145,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t>(t, a, s);
}
if(r < 0)
else
throw std::runtime_error("Without supported instances!");
return r;
}

View File

@@ -1,6 +1,5 @@
# run from top of ck folder
EXE=build/bin/tile_add_rmsnorm2d_rdquant_fwd
#!/bin/sh
EXE="$(find . -name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)"
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000

View File

@@ -1,6 +1,5 @@
#!/bin/sh
# call from top of CK folder
EXE=./build/bin/tile_add_rmsnorm2d_rdquant_fwd
EXE="$(find . -name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)"
for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13