diff --git a/example/ck_tile/02_layernorm2d/README.md b/example/ck_tile/02_layernorm2d/README.md index 14c6fc0d67..3573d70cd2 100644 --- a/example/ck_tile/02_layernorm2d/README.md +++ b/example/ck_tile/02_layernorm2d/README.md @@ -69,7 +69,7 @@ args: ``` ## limitations -Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by default generated. though our kernel template suppor this. (TBD: add some flag in generate.py) to generate those instance on demand. Beside, N>8192 case will by default using two-pass pipeline, and `-fquant=1/2` are not supported yet. +Note that `fquant=2`, `fadd=2`, and `prec_sx/prec_sy` other than `fp32` are not generated by default, though our kernel template supports them. (TBD: add some flag in generate.py to generate those instances on demand.) Besides, the `N>8192` case will by default use the two-pass pipeline, and `-fquant=1/2` are not supported yet. If you need to support `N>8192` and `fused+residual+store`, you can use this example together with `12_smoothquant` to construct two kernels, layernorm+residual and smoothquant, for this purpose. ``` # some case @@ -82,4 +82,4 @@ Note that `fquant=2`, `fadd=2`, `prec_sx/prec_sy` other than `fp32` are not by d # standard fp16 layernorm 2d, m=10. 
n=1024, fused-smooth-quant+fused-add-store, output in int8 ./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1 -``` \ No newline at end of file +``` diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 300f6c05e1..bf576db97e 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -202,8 +202,9 @@ float layernorm2d_fwd_(const S& s, A a) using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem; using Default2DEpilogue = ck_tile::Default2DEpilogue; - using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; + static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1; + using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem>; using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue; diff --git a/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp index 2e29604116..3dec404b4b 100644 --- a/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp @@ -8,17 +8,23 @@ namespace ck_tile { -template +template struct DynamicQuantEpilogueTraits { - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool UseRawStore = UseRawStore_; - static constexpr bool UseMax3 = UseMax3_; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool UseSmoothInputScale = UseSmoothInputScale_; + static constexpr bool UseRawStore = UseRawStore_; + static constexpr bool UseMax3 = UseMax3_; }; // this epilogue just store out a M*N matrix, row major template ; + using XScaleDataType = remove_cvref_t; using YScaleDataType = remove_cvref_t; using ODataType = remove_cvref_t; using BlockShape = remove_cvref_t; // can consum generic 2d shape using Traits = remove_cvref_t; }; +// TODO: we should put 
descriptor creation function into policy template struct DynamicQuantEpilogue { using Problem = remove_cvref_t; using AccDataType = remove_cvref_t; + using XScaleDataType = remove_cvref_t; using YScaleDataType = remove_cvref_t; using ODataType = remove_cvref_t; using BlockShape = remove_cvref_t; @@ -63,6 +72,33 @@ struct DynamicQuantEpilogue return BlockReduce2dCrossWarpSync{}; } + CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution() + { + using S = BlockShape; +#if 0 + // don't remove this + // Note that if we set encoding purposely like this, you will result in compile fail + // TODO: x_scale create local-scratch to accept arbitrary acc input (with same length) + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<2, 2>>, + sequence<0, 1, 1>, + sequence<0, 0, 3>>{}); +#else + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 1>, + sequence<0, 3>>{}); +#endif + } + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync(); @@ -71,8 +107,12 @@ struct DynamicQuantEpilogue // TODO: this function assume store out vector size is the same as OAccTile last dimension size // how do we fix this ? 
- template + template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, + const XScaleWindow& x_scale_window_, YScaleWindow& y_scale_window, const OAccTile& o_acc_tile, void* smem) @@ -80,6 +120,18 @@ struct DynamicQuantEpilogue auto reduce = GetBlockReduce2d(); auto reduce_sync = GetBlockReduce2dSync(); auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync(); + const auto x_scale_window = + make_tile_window(x_scale_window_, MakeSmoothInputScaleTileDistribution()); + + auto x_scale = load_tile(x_scale_window); + + auto o_acc_tmp = o_acc_tile; + + sweep_tile(o_acc_tmp, [&](auto idx) { + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + const auto xs_ = type_convert(x_scale[j_idx]); + o_acc_tmp(idx) = o_acc_tmp(idx) * xs_; + }); const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); }; @@ -87,10 +139,9 @@ struct DynamicQuantEpilogue constexpr auto y_size_per_row = OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at( number<1>{}); - // constexpr auto y_size_per_row = OAccTile::get_lengths()[number<1>{}]; if constexpr(UseMax3 && std::is_same_v && y_size_per_row % 2 == 0) { - // fast max3 implementation + // fast max3+abs implementation const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) { float rtn; asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)" @@ -98,11 +149,11 @@ struct DynamicQuantEpilogue : "v"(acc_), "v"(v_0_), "v"(v_1_)); return rtn; }; - return reduce(o_acc_tile, type_convert(0), f_max3, sequence<1, 2>{}); + return reduce(o_acc_tmp, type_convert(0), f_max3, sequence<1, 2>{}); } else { - return reduce(o_acc_tile, type_convert(0), f_absmax); + return reduce(o_acc_tmp, type_convert(0), f_absmax); } }(); reduce_sync(row_absmax, f_absmax); @@ -117,23 +168,20 @@ struct DynamicQuantEpilogue store_tile(y_scale_window, cast_tile(y_scale)); - auto o_acc_scaled_tile = - make_static_distributed_tensor(o_acc_tile.get_tile_distribution()); - - sweep_tile(o_acc_tile, [&](auto idx) { - constexpr 
auto row_id = make_tuple(idx[number<0>{}]); - o_acc_scaled_tile(idx) = o_acc_tile[idx] / y_scale(row_id); + sweep_tile(o_acc_tmp, [&](auto idx) { + constexpr auto row_id = make_tuple(idx[number<0>{}]); + o_acc_tmp(idx) = o_acc_tmp[idx] / y_scale(row_id); }); // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { - store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_scaled_tile)); + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tmp)); buffer_store_fence(); } else { - store_tile(o_dram_window_tmp, cast_tile(o_acc_scaled_tile)); + store_tile(o_dram_window_tmp, cast_tile(o_acc_tmp)); } } }; diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp index 02fd5f7b93..1de230c144 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp @@ -45,7 +45,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelford() { - using P_ = BlockWelfordProblem; @@ -55,7 +55,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordSync() { - using P_ = BlockWelfordProblem; @@ -65,7 +65,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetBlockWelfordCrossWarpSync() { - using P_ = BlockWelfordProblem; @@ -77,13 +77,13 @@ struct Layernorm2dFwdPipelineDefaultPolicy { if constexpr(Problem::kNeedCrossWarpSync) { - using P_ = BlockWelfordProblem; using block_welford = BlockWelford; using x_block_tile = - decltype(make_static_distributed_tensor( + decltype(make_static_distributed_tensor( MakeXBlockTileDistribution())); using mean_var_block_tile = decltype(block_welford::template MakeMeanVarBlockTile()); diff --git 
a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index 5601f3a68c..83cdab428e 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -87,12 +87,9 @@ struct Layernorm2dFwdPipelineOnePass x_residual_window_, Policy::template MakeXBlockTileDistribution()); auto y_residual_window = make_tile_window( y_residual_window_, Policy::template MakeXBlockTileDistribution()); - const auto x_scale_window = make_tile_window( - x_scale_window_, Policy::template MakeGammaBetaBlockTileDistribution()); - auto x = load_tile(x_window); - auto x_resi = load_tile(x_residual_window); - auto x_scale = load_tile(x_scale_window); + auto x = load_tile(x_window); + auto x_resi = load_tile(x_residual_window); int cur_count = 0; int max_count = @@ -106,21 +103,21 @@ struct Layernorm2dFwdPipelineOnePass const auto gamma = load_tile(gamma_window); const auto beta = load_tile(beta_window); + auto acc = cast_tile(x); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) { sweep_tile(x_resi, [&](auto idx) { // compute x = x_resi + x - auto re_ = type_convert(x_resi(idx)) + - type_convert(x(idx)); - x(idx) = type_convert(re_); + acc(idx) = type_convert(x_resi(idx)) + acc(idx); }); if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) - store_tile(y_residual_window, x); + store_tile(y_residual_window, cast_tile(acc)); } // compute welford each-thread->cross-lane->cross-warp - auto [mean, var] = block_welford(x, cur_count, max_count); + auto [mean, var] = block_welford(acc, cur_count, max_count); block_welford_sync(mean, var, cur_count); block_welford_cross_warp_sync(mean, var, cur_count, smem); block_tile_welford_post_scale_var(var, cur_count); @@ -138,7 +135,7 @@ struct 
Layernorm2dFwdPipelineOnePass store_tile(inv_std_window, cast_tile(inv_std)); // layernorm computation - auto ln = make_static_distributed_tensor(x.get_tile_distribution()); + auto ln = make_static_distributed_tensor(acc.get_tile_distribution()); sweep_tile(ln, [&, mean_ = mean](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); constexpr auto j_idx = make_tuple(idx[number<1>{}]); @@ -146,26 +143,15 @@ struct Layernorm2dFwdPipelineOnePass const auto gamma_ = type_convert(gamma[j_idx]); const auto beta_ = type_convert(beta[j_idx]); - const auto x_ = type_convert(x[idx]); - auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; + auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; ln(idx) = ln_; }); - if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) - { - // smooth-quant pre-scale, then run rowwise-quant - sweep_tile(ln, [&](auto idx) { - constexpr auto j_idx = make_tuple(idx[number<1>{}]); - const auto xs_ = type_convert(x_scale[j_idx]); - ln(idx) = ln(idx) * xs_; - }); - } - if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT || kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT) { - Epilogue{}(y_window_, y_scale_window, ln, smem); + Epilogue{}(y_window_, x_scale_window_, y_scale_window, ln, smem); } else Epilogue{}(y_window_, ln); diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index 48f66739da..fadf56dfd3 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -106,7 +106,7 @@ struct Layernorm2dFwdPipelineTwoPass auto block_welford_cross_warp_sync = Policy::template GetBlockWelfordCrossWarpSync(); - using XTensorType = decltype(load_tile(x_window)); + using XTensorType = decltype(cast_tile(load_tile(x_window))); auto 
mean = block_welford.template MakeMeanVarBlockTile(); auto var = block_welford.template MakeMeanVarBlockTile(); @@ -117,22 +117,22 @@ struct Layernorm2dFwdPipelineTwoPass move_tile_window(x_window, {0, Block_N}); move_tile_window(x_residual_window, {0, Block_N}); + auto acc = cast_tile(x); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) { sweep_tile(x_resi, [&](auto idx) { // compute x = x_resi + x - auto re_ = type_convert(x_resi(idx)) + - type_convert(x(idx)); - x(idx) = type_convert(re_); + acc(idx) = type_convert(x_resi(idx)) + acc(idx); }); if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE) { - store_tile(y_residual_window, x); + store_tile(y_residual_window, cast_tile(acc)); move_tile_window(y_residual_window, {0, Block_N}); } } - block_welford(x, mean, var, cur_count, max_count); + block_welford(acc, mean, var, cur_count, max_count); } block_welford_sync(mean, var, cur_count); @@ -166,21 +166,21 @@ struct Layernorm2dFwdPipelineTwoPass { auto x = load_tile(x_window); auto x_resi = load_tile(x_residual_window); + auto acc = cast_tile(x); + if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE || kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD) { sweep_tile(x_resi, [&](auto idx) { // compute x = x_resi + x - auto re_ = type_convert(x_resi(idx)) + - type_convert(x(idx)); - x(idx) = type_convert(re_); + acc(idx) = type_convert(x_resi(idx)) + acc(idx); }); } // load gamma/beta (TODO: support no gamma/beta?) 
const auto gamma = load_tile(gamma_window); const auto beta = load_tile(beta_window); - auto ln = make_static_distributed_tensor(x.get_tile_distribution()); + auto ln = make_static_distributed_tensor(acc.get_tile_distribution()); sweep_tile(ln, [&, mean_ = mean](auto idx) { constexpr auto i_idx = make_tuple(idx[number<0>{}]); @@ -189,8 +189,7 @@ struct Layernorm2dFwdPipelineTwoPass const auto gamma_ = type_convert(gamma[j_idx]); const auto beta_ = type_convert(beta[j_idx]); - const auto x_ = type_convert(x[idx]); - auto ln_ = (x_ - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; + auto ln_ = (acc(idx) - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; ln(idx) = ln_; });