From 72c98db0e076a404fe836b22c99178be16d7941b Mon Sep 17 00:00:00 2001 From: zanzhang Date: Fri, 22 Aug 2025 15:27:01 +0800 Subject: [PATCH] fused complete --- example/ck_tile/02_layernorm2d/generate.py | 103 ++++---- include/ck_tile/core/tensor/null_tensor.hpp | 4 + .../core/tensor/static_distributed_tensor.hpp | 5 + .../ops/epilogue/dynamic_quant_epilogue.hpp | 104 +++++--- .../kernel/layernorm2d_fwd_kernel.hpp | 2 +- ...ayernorm2d_fwd_pipeline_default_policy.hpp | 2 +- .../layernorm2d_fwd_pipeline_one_pass.hpp | 26 +- .../norm_reduce/block/block_norm_reduce.hpp | 227 +++++++++++++++++- 8 files changed, 386 insertions(+), 87 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 6fef66b89b..053105b05c 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -520,57 +520,58 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, fused_add_list = [0, 1] fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant # rm rn tm tn vn pd mv fdiv welford 2p xbias add sweep - h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], - '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], - '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], - '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], - '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], - '5120' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 320, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 320, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1280, 1, True, False, True, True, False, 0, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, True, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, True, True, 0, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]} + # h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + # '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + # '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + # '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + # '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)], + # '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + # '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + # '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)], + # '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + # '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + h_trait_dict = {'5120' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 320, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 320, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 640, 8, True, False, True, True, False, 0, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1280, 1, True, False, True, True, False, 0, 0, 0)]} + # '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + # '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)], + # 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, True, True, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, True, True, 0, 0, 0), + # h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] diff --git a/include/ck_tile/core/tensor/null_tensor.hpp b/include/ck_tile/core/tensor/null_tensor.hpp index 565ff87dff..f6f8665a41 100644 --- a/include/ck_tile/core/tensor/null_tensor.hpp +++ b/include/ck_tile/core/tensor/null_tensor.hpp @@ -7,6 +7,10 @@ namespace ck_tile { struct null_tensor { + CK_TILE_HOST_DEVICE static constexpr auto is_valid() + { + return false; + } }; } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index b73a27c8d5..77bfb165c7 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -48,6 +48,11 @@ struct static_distributed_tensor return StaticTileDistribution{}; } + CK_TILE_HOST_DEVICE static constexpr auto is_valid() + { + return true; + } + CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans() { return StaticTileDistribution::get_distributed_spans(); diff --git a/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp index c8168a1eed..fbc09a1659 100644 --- a/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp @@ -88,6 +88,23 @@ struct DynamicQuantEpilogue sequence<0, 1, 1>, sequence<0, 0, 3>>{}); #else + // return make_static_tile_distribution( + // tile_distribution_encoding< + // sequence, + // tuple>, + // tuple, sequence<0, 1>>, + // tuple, sequence<1, 2>>, + // sequence<1, 1>, + // sequence<0, 3>>{}); + // return make_static_tile_distribution( + // tile_distribution_encoding< + // sequence<>, + // tuple, + // sequence>, + // tuple, sequence<2>>, + // tuple, sequence<2>>, + // sequence<2, 2>, + // sequence<0, 3>>{}); return make_static_tile_distribution( tile_distribution_encoding< sequence, @@ -96,6 +113,15 @@ struct DynamicQuantEpilogue tuple, sequence<1, 2>>, sequence<1, 1>, sequence<0, 3>>{}); + // return make_static_tile_distribution( + // tile_distribution_encoding< + // sequence<1>, + // tuple, + // sequence>, + // tuple, sequence<0, 1>>, + // tuple, sequence<1, 2>>, + // sequence<1, 1>, + // sequence<0, 3>>{}); #endif } @@ -108,17 +134,14 @@ struct DynamicQuantEpilogue template CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_dram_window_tmp, YScaleWindow& y_scale_window, - const OAccTile& o_acc_tile, + OAccTile& o_acc_tile, void* smem) { auto reduce = GetBlockReduce2d(); auto reduce_sync = GetBlockReduce2dSync(); auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync(); - auto o_acc_tmp = o_acc_tile; - const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); }; - auto row_absmax = [&]() { constexpr auto y_size_per_row = OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at( @@ -133,39 +156,49 @@ struct DynamicQuantEpilogue : "v"(acc_), "v"(v_0_), "v"(v_1_)); return rtn; }; - return reduce(o_acc_tmp, type_convert(0), f_max3, sequence<1, 2>{}); + return reduce(o_acc_tile, type_convert(0), f_max3, sequence<1, 2>{}); } else { - return reduce(o_acc_tmp, type_convert(0), f_absmax); + return reduce(o_acc_tile, type_convert(0), f_absmax); } }(); reduce_sync(row_absmax, f_absmax); reduce_crosswarp_sync(row_absmax, smem, f_absmax); // here y_scale is Acc TYpe, need convert to YScale type later + auto max_scale = 1 / type_convert(numeric::max()); auto y_scale = tile_elementwise_in( [&](const auto& v_) { - return v_ / type_convert(numeric::max()); + return v_ * max_scale ; }, row_absmax); store_tile(y_scale_window, cast_tile(y_scale)); - sweep_tile(o_acc_tmp, [&](auto idx) { - constexpr auto row_id = make_tuple(idx[number<0>{}]); - o_acc_tmp(idx) = o_acc_tmp[idx] / y_scale(row_id); - }); + if constexpr(y_scale.get_thread_buffer_size() == 1) + { + auto scale = 1 / y_scale.get_thread_buffer().get(0); + sweep_tile(o_acc_tile, [&](auto idx) { + o_acc_tile(idx) = o_acc_tile[idx] * scale ; + }); + } + else + { + sweep_tile(o_acc_tile, [&](auto idx) { + constexpr auto row_id = make_tuple(idx[number<0>{}]); + o_acc_tile(idx) = o_acc_tile[idx] / y_scale(row_id); + }); + } - // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { - store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tmp)); + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); buffer_store_fence(); } else { - store_tile(o_dram_window_tmp, cast_tile(o_acc_tmp)); + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); } } @@ -180,23 +213,21 @@ struct DynamicQuantEpilogue CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const SmoothScaleWindow& sm_scale_window_, YScaleWindow& y_scale_window, - const OAccTile& o_acc_tile, + OAccTile& o_acc_tile, void* smem) { - const auto sm_scale_window = - make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution()); + // const auto sm_scale_window = + // make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution()); + // + // auto sm_scale = load_tile(sm_scale_window); + // + // sweep_tile(o_acc_tile, [&](auto idx) { + // constexpr auto j_idx = make_tuple(idx[number<1>{}]); + // const auto xs_ = type_convert(sm_scale[j_idx]); + // o_acc_tile(idx) = o_acc_tile(idx) * xs_; + // }); - auto sm_scale = load_tile(sm_scale_window); - - auto o_acc_tmp = o_acc_tile; - - sweep_tile(o_acc_tmp, [&](auto idx) { - constexpr auto j_idx = make_tuple(idx[number<1>{}]); - const auto xs_ = type_convert(sm_scale[j_idx]); - o_acc_tmp(idx) = o_acc_tmp(idx) * xs_; - }); - - Impl(o_dram_window_tmp, y_scale_window, o_acc_tmp, smem); + Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem); } // Dynamic Quant @@ -208,5 +239,22 @@ struct DynamicQuantEpilogue { Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem); } + + template + CK_TILE_DEVICE auto + operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const + { + // TODO: this is ugly + if constexpr(UseRawStore && (kPadM || kPadN)) + { + store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + buffer_store_fence(); + } + else + { + store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); + } + } + }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 146ac40fb7..885aac3427 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -367,7 +367,7 @@ struct Layernorm2dFwd return pad_tensor_view(tmp_0_, make_tuple(number{}), - sequence{}); // sm_scale no need pad + sequence{}); // sm_scale no need pad }(); return make_tile_window(win_, make_tuple(number{}), {0}); } diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp index 37f87b4fe0..e0989d91ce 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp @@ -95,7 +95,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy using mean_var_block_tile = decltype(block_welford::template MakeMeanVarBlockTile()); - return GetBlockNormReduceCrossWarpSync() + return 2 * GetBlockNormReduceCrossWarpSync() .template GetSmemSize(); } else diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index 6c716d9682..bb677f61b6 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -52,7 +52,20 @@ struct Layernorm2dFwdPipelineOnePass CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { - return Policy::template GetSmemSize(); + return 2 * Policy::template GetSmemSize(); + } + + CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution() + { + using S = Problem::BlockShape; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 1>, + sequence<0, 3>>{}); } template (x); if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS) @@ -147,6 +163,9 @@ struct Layernorm2dFwdPipelineOnePass block_norm_reduce(acc, mean, var, cur_count, max_count); block_norm_reduce_sync(mean, var, cur_count); block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem); + + + auto sm_scale = load_tile(sm_scale_window); if(kWelford) { block_tile_welford_post_scale_var(var, cur_count, constant{}); @@ -189,6 +208,11 @@ struct Layernorm2dFwdPipelineOnePass auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_; ln(idx) = ln_; + if constexpr(sm_scale.is_valid()) + { + const auto xs_ = type_convert(sm_scale[j_idx]); + ln(idx) = ln(idx) * xs_; + } }); if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT || diff --git a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp index 88da6be86e..51c0572c25 100644 --- a/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp @@ -46,7 +46,7 @@ struct BlockNormReduce constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0); auto x = ck_tile::type_convert(x_tensor[in_dstr_idx]); - if(kWelford) + if constexpr(kWelford) { welford_update(mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), @@ -64,6 +64,42 @@ struct BlockNormReduce }); } + template + CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor, + MeanDistributedTensor_& mean_tensor, + VarDistributedTensor_& var_tensor, + MinDistributedTensor_& min_tensor, + MaxDistributedTensor_& max_tensor, + int& cur_count_, // -> prefer init as zero + const int& max_count_) + { + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + constexpr auto spans = XDistributedTensor_::get_distributed_spans(); + + sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) { + if(cur_count_ < max_count_) + { + ++cur_count_; + sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) { + constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1); + constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0); + + auto x = ck_tile::type_convert(x_tensor[in_dstr_idx]); + mean_tensor(out_dstr_idx) += x; + var_tensor(out_dstr_idx) += x * x; + min_tensor(out_dstr_idx) = ck_tile::min(x, min_tensor(out_dstr_idx)); + max_tensor(out_dstr_idx) = ck_tile::max(x, max_tensor(out_dstr_idx)); + }); + } + }); + } + template CK_TILE_DEVICE static auto MakeMeanVarBlockTile() { @@ -162,7 +198,7 @@ struct BlockNormReduceSync // pull data from remote lane const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane); const auto v_remote_var = warp_shuffle(v_local_var, src_lane); - if(kWelford) + if constexpr(kWelford) { const auto v_remote_count = warp_shuffle(v_local_count, src_lane); @@ -192,6 +228,95 @@ struct BlockNormReduceSync } }); } + + template + CK_TILE_DEVICE void + operator()(MeanDistributedTensor_& mean_tensor, + VarDistributedTensor_& var_tensor, + MinDistributedTensor_& min_tensor, + MaxDistributedTensor_& max_tensor, + int& count) + { + using Dstr = typename MeanDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + static_assert(std::is_same_v, + "wrong!"); + + constexpr index_t NDimP = Dstr::get_num_of_dimension_p(); + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_lane = NDimP - 1; + + // const auto ps_idx = make_array(get_warp_id(), get_lane_id()); + // const auto rs_idx = + // mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); + + constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size(); + static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()); + + const int original_count = count; + + // loop over thread data + static_for<0, thread_buf_size, 1>{}([&](auto i) { + auto v_local_mean = mean_tensor.get_thread_buffer()[i]; + auto v_local_var = var_tensor.get_thread_buffer()[i]; + auto v_local_min = min_tensor.get_thread_buffer()[i]; + auto v_local_max = max_tensor.get_thread_buffer()[i]; + auto v_local_count = original_count; + + // cross-lane reduce for replication + // only reduce on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // reduction sweep forward + static_for<0, nstage, 1>{}([&](auto istage) { + // xor + index_t src_lane = + (__lane_id()) ^ + (number{}.value); + + // pull data from remote lane + const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane); + const auto v_remote_var = warp_shuffle(v_local_var, src_lane); + v_local_mean += v_remote_mean; + v_local_var += v_remote_var; + const auto v_remote_min = warp_shuffle(v_local_min, src_lane); + const auto v_remote_max = warp_shuffle(v_local_max, src_lane); + v_local_min = ck_tile::min(v_remote_min, v_local_min); + v_local_max = ck_tile::max(v_remote_max, v_local_min); + }); + } + }); + + mean_tensor.get_thread_buffer()(i) = v_local_mean; + var_tensor.get_thread_buffer()(i) = v_local_var; + max_tensor.get_thread_buffer()(i) = v_local_min; + min_tensor.get_thread_buffer()(i) = v_local_max; + if constexpr(kWelford) + { + count = v_local_count; + } + }); + } + }; template @@ -290,7 +415,7 @@ struct BlockNormReduceCrossWarpSync smem_dtype local_scratch_; local_scratch_[0] = bit_cast(mean_tensor.get_thread_buffer()[i]); local_scratch_[1] = bit_cast(var_tensor.get_thread_buffer()[i]); - if(kWelford) + if constexpr(kWelford) { local_scratch_[2] = bit_cast(count); } @@ -326,7 +451,7 @@ struct BlockNormReduceCrossWarpSync const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; const auto v_remote_mean = bit_cast(v_remote[0]); const auto v_remote_var = bit_cast(v_remote[1]); - if(kWelford) + if constexpr(kWelford) { const auto v_remote_count = bit_cast(v_remote[2]); @@ -347,10 +472,102 @@ struct BlockNormReduceCrossWarpSync mean_tensor.get_thread_buffer()(i_0) = v_local_mean; var_tensor.get_thread_buffer()(i_0) = v_local_var; - if(kWelford) + if constexpr(kWelford) count = v_local_count; }); } + + template + CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor, + VarDistributedTensor_& var_tensor, + MinDistributedTensor_& min_tensor, + MaxDistributedTensor_& max_tensor, + int& count, + void* smem) + { + using DataType = typename MeanDistributedTensor_::DataType; + using Dstr = typename MeanDistributedTensor_::StaticTileDistribution; + // using DstrEncode = typename Dstr::DstrEncode; + // using DstrEncodeDetail = typename DstrEncode::detail; + + static_assert(std::is_same_v, + "wrong!"); + + constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size(); + static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size()); + + using fused_smem_dtype = fp32x4_t; + // Note: we always pack everything into fp32x4 + fused_smem_dtype* smem_ptr = reinterpret_cast(smem); + const index_t lane_id = get_lane_id(); + const index_t warp_id = get_warp_id(); + constexpr auto num_reduce_warps = GetReduceWarps(); + constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size(); + const index_t smem_offset = warp_id; + + // skip if nonthing to do + if constexpr(num_reduce_warps == 1) + return; + + // store into smem only for lane-0 within one warp + if(lane_id == 0) + { + static_for<0, thread_buf_size, 1>{}([&](auto i) { + fused_smem_dtype local_scratch_; + local_scratch_[0] = bit_cast(mean_tensor.get_thread_buffer()[i]); + local_scratch_[1] = bit_cast(var_tensor.get_thread_buffer()[i]); + local_scratch_[2] = bit_cast(min_tensor.get_thread_buffer()[i]); + local_scratch_[3] = bit_cast(max_tensor.get_thread_buffer()[i]); + smem_ptr[smem_offset + i * num_warps] = local_scratch_; + }); + } + block_sync_lds(); + + // load from smem. here we let everythread to do compute :) + index_t local_warp_id = warp_id / num_reduce_warps; + index_t local_smem_os = local_warp_id * num_reduce_warps; + fused_smem_dtype all_scratch[thread_buf_size * num_reduce_warps]; + static_for<0, thread_buf_size, 1>{}([&](auto i_0) { + static_for<0, num_reduce_warps, 1>{}([&](auto i_1) { + all_scratch[i_0 * num_reduce_warps + i_1] = + smem_ptr[i_0 * num_warps + local_smem_os + i_1]; + }); + }); + block_sync_lds(); // TODO: we don't need sync here + + // const int original_count = count; + + static_for<0, thread_buf_size, 1>{}([&](auto i_0) { + // TODO: use descriptor for this + auto v_local = all_scratch[i_0 * num_reduce_warps]; + auto v_local_mean = bit_cast(v_local[0]); + auto v_local_var = bit_cast(v_local[1]); + auto v_local_min = bit_cast(v_local[2]); + auto v_local_max = bit_cast(v_local[3]); + + // further reduce mean/var + static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) { + constexpr auto i_1 = number{}; + const fused_smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; + const auto v_remote_mean = bit_cast(v_remote[0]); + const auto v_remote_var = bit_cast(v_remote[1]); + v_local_mean += v_remote_mean; + v_local_var += v_remote_var; + const auto v_remote_min = bit_cast(v_remote[2]); + const auto v_remote_max = bit_cast(v_remote[3]); + v_local_min = ck_tile::min(v_remote_min, v_local_min); + v_local_max = ck_tile::max(v_remote_max, v_local_max); + }); + + mean_tensor.get_thread_buffer()(i_0) = v_local_mean; + var_tensor.get_thread_buffer()(i_0) = v_local_var; + min_tensor.get_thread_buffer()(i_0) = v_local_min; + max_tensor.get_thread_buffer()(i_0) = v_local_max; + }); + } }; // compute the max count for a last dim reduce