mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[Ck_tile] smoothquant (#1617)
* fix compile error * fix typo of padding * Add smoothquant op * Add smoothquant instance library * refine type * add test script * Re-generate smoothquant.hpp * Always use 'current year' in copyright * use Generic2dBlockShape instead * Add vector = 8 instance back * Find exe path automatically * Simplify the api condition * Remove debugging code * update year * Add blank line between function declaration * explicitly cast return value to dim3 * refine return value * Fix default warmup and repeat value * Add comment * refactor sommthquant cmake * Add README * Fix typo --------- Co-authored-by: Po Yen, Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -18,7 +18,7 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::half_t>
|
||||
using BDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YScaleDataType = ck_tile::half_t;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
@@ -30,7 +30,7 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t>
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YScaleDataType = ck_tile::bf16_t;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
@@ -101,7 +101,7 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::AddRmsnorm2dRdquantShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveX = kSaveX_;
|
||||
|
||||
@@ -66,7 +66,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using BDataType = DataType;
|
||||
using GammaDataType = DataType;
|
||||
using XDataType = DataType;
|
||||
using YScaleDataType = DataType;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
|
||||
@@ -99,12 +99,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
constexpr bool kThreePass = true;
|
||||
|
||||
using BlockWarps = ck_tile::sequence<2, 2>;
|
||||
using BlockTile = ck_tile::sequence<2, 128>;
|
||||
using BlockWarps = ck_tile::sequence<4, 1>;
|
||||
using BlockTile = ck_tile::sequence<4, 128>;
|
||||
using WarpTile = ck_tile::sequence<1, 64>;
|
||||
using Vector = ck_tile::sequence<1, 1>;
|
||||
|
||||
using Shape = ck_tile::AddRmsnorm2dRdquantShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
GammaDataType,
|
||||
|
||||
@@ -28,7 +28,6 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
|
||||
add_rmsnorm2d_rdquant_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
#if 1
|
||||
float r = -1;
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
@@ -128,9 +127,6 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, true, true>>(s, a);
|
||||
}
|
||||
return r;
|
||||
#else
|
||||
return add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, true, false>>(s, a);
|
||||
#endif
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -139,7 +135,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
|
||||
float r = -1;
|
||||
// Only support instance of save_x == true for now
|
||||
assert(t.save_x);
|
||||
if(t.data_type.compare("fp16") == 0)
|
||||
@@ -150,8 +145,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
|
||||
{
|
||||
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t>(t, a, s);
|
||||
}
|
||||
if(r < 0)
|
||||
else
|
||||
throw std::runtime_error("Without supported instances!");
|
||||
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
|
||||
# run from top of ck folder
|
||||
EXE=build/bin/tile_add_rmsnorm2d_rdquant_fwd
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)"
|
||||
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#!/bin/sh
|
||||
# call from top of CK folder
|
||||
EXE=./build/bin/tile_add_rmsnorm2d_rdquant_fwd
|
||||
EXE="$(find . -name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)"
|
||||
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
$EXE -prec=$pr_i -m=99 -n=13
|
||||
|
||||
Reference in New Issue
Block a user