mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
fused complete
This commit is contained in:
@@ -520,57 +520,58 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
fused_add_list = [0, 1]
|
||||
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
|
||||
# rm rn tm tn vn pd mv fdiv welford 2p xbias add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'5120' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 320, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 320, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1280, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]}
|
||||
# h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
# '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
# '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
# '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
# '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
# '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
# '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
# '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
# '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
# '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
h_trait_dict = {'5120' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 320, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 320, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 640, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1280, 1, True, False, True, True, False, 0, 0, 0)]}
|
||||
# '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
# '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
# 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, True, True, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, True, True, 0, 0, 0),
|
||||
# h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]}
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
hs = h_trait_dict[hs_key]
|
||||
|
||||
@@ -7,6 +7,10 @@ namespace ck_tile {
|
||||
|
||||
struct null_tensor
|
||||
{
|
||||
CK_TILE_HOST_DEVICE static constexpr auto is_valid()
|
||||
{
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -48,6 +48,11 @@ struct static_distributed_tensor
|
||||
return StaticTileDistribution{};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto is_valid()
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
|
||||
{
|
||||
return StaticTileDistribution::get_distributed_spans();
|
||||
|
||||
@@ -88,6 +88,23 @@ struct DynamicQuantEpilogue
|
||||
sequence<0, 1, 1>,
|
||||
sequence<0, 0, 3>>{});
|
||||
#else
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding<
|
||||
// sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
// tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
// tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
// tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
// sequence<1, 1>,
|
||||
// sequence<0, 3>>{});
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding<
|
||||
// sequence<>,
|
||||
// tuple<sequence<1>,
|
||||
// sequence<S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
// tuple<sequence<2>, sequence<2>>,
|
||||
// tuple<sequence<1>, sequence<2>>,
|
||||
// sequence<2, 2>,
|
||||
// sequence<0, 3>>{});
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
@@ -96,6 +113,15 @@ struct DynamicQuantEpilogue
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 3>>{});
|
||||
// return make_static_tile_distribution(
|
||||
// tile_distribution_encoding<
|
||||
// sequence<1>,
|
||||
// tuple<sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
// sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
// tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
// tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
// sequence<1, 1>,
|
||||
// sequence<0, 3>>{});
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -108,17 +134,14 @@ struct DynamicQuantEpilogue
|
||||
template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
|
||||
CK_TILE_DEVICE auto Impl(ODramWindowTmp& o_dram_window_tmp,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
auto reduce = GetBlockReduce2d();
|
||||
auto reduce_sync = GetBlockReduce2dSync();
|
||||
auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
|
||||
|
||||
auto o_acc_tmp = o_acc_tile;
|
||||
|
||||
const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
|
||||
|
||||
auto row_absmax = [&]() {
|
||||
constexpr auto y_size_per_row =
|
||||
OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
|
||||
@@ -133,39 +156,49 @@ struct DynamicQuantEpilogue
|
||||
: "v"(acc_), "v"(v_0_), "v"(v_1_));
|
||||
return rtn;
|
||||
};
|
||||
return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
|
||||
return reduce(o_acc_tile, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_absmax);
|
||||
return reduce(o_acc_tile, type_convert<AccDataType>(0), f_absmax);
|
||||
}
|
||||
}();
|
||||
reduce_sync(row_absmax, f_absmax);
|
||||
reduce_crosswarp_sync(row_absmax, smem, f_absmax);
|
||||
|
||||
// here y_scale is Acc TYpe, need convert to YScale type later
|
||||
auto max_scale = 1 / type_convert<AccDataType>(numeric<ODataType>::max());
|
||||
auto y_scale = tile_elementwise_in(
|
||||
[&](const auto& v_) {
|
||||
return v_ / type_convert<AccDataType>(numeric<ODataType>::max());
|
||||
return v_ * max_scale ;
|
||||
},
|
||||
row_absmax);
|
||||
|
||||
store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));
|
||||
|
||||
sweep_tile(o_acc_tmp, [&](auto idx) {
|
||||
constexpr auto row_id = make_tuple(idx[number<0>{}]);
|
||||
o_acc_tmp(idx) = o_acc_tmp[idx] / y_scale(row_id);
|
||||
});
|
||||
if constexpr(y_scale.get_thread_buffer_size() == 1)
|
||||
{
|
||||
auto scale = 1 / y_scale.get_thread_buffer().get(0);
|
||||
sweep_tile(o_acc_tile, [&](auto idx) {
|
||||
o_acc_tile(idx) = o_acc_tile[idx] * scale ;
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile(o_acc_tile, [&](auto idx) {
|
||||
constexpr auto row_id = make_tuple(idx[number<0>{}]);
|
||||
o_acc_tile(idx) = o_acc_tile[idx] / y_scale(row_id);
|
||||
});
|
||||
}
|
||||
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,23 +213,21 @@ struct DynamicQuantEpilogue
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const SmoothScaleWindow& sm_scale_window_,
|
||||
YScaleWindow& y_scale_window,
|
||||
const OAccTile& o_acc_tile,
|
||||
OAccTile& o_acc_tile,
|
||||
void* smem)
|
||||
{
|
||||
const auto sm_scale_window =
|
||||
make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
|
||||
// const auto sm_scale_window =
|
||||
// make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
|
||||
//
|
||||
// auto sm_scale = load_tile(sm_scale_window);
|
||||
//
|
||||
// sweep_tile(o_acc_tile, [&](auto idx) {
|
||||
// constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
// const auto xs_ = type_convert<AccDataType>(sm_scale[j_idx]);
|
||||
// o_acc_tile(idx) = o_acc_tile(idx) * xs_;
|
||||
// });
|
||||
|
||||
auto sm_scale = load_tile(sm_scale_window);
|
||||
|
||||
auto o_acc_tmp = o_acc_tile;
|
||||
|
||||
sweep_tile(o_acc_tmp, [&](auto idx) {
|
||||
constexpr auto j_idx = make_tuple(idx[number<1>{}]);
|
||||
const auto xs_ = type_convert<AccDataType>(sm_scale[j_idx]);
|
||||
o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
|
||||
});
|
||||
|
||||
Impl(o_dram_window_tmp, y_scale_window, o_acc_tmp, smem);
|
||||
Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem);
|
||||
}
|
||||
|
||||
// Dynamic Quant
|
||||
@@ -208,5 +239,22 @@ struct DynamicQuantEpilogue
|
||||
{
|
||||
Impl(o_dram_window_tmp, y_scale_window, o_acc_tile, smem);
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmp, typename OAccTile>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const
|
||||
{
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -367,7 +367,7 @@ struct Layernorm2dFwd
|
||||
|
||||
return pad_tensor_view(tmp_0_,
|
||||
make_tuple(number<Block_N>{}),
|
||||
sequence<false>{}); // sm_scale no need pad
|
||||
sequence<kPadN>{}); // sm_scale no need pad
|
||||
}();
|
||||
return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
|
||||
using mean_var_block_tile =
|
||||
decltype(block_welford::template MakeMeanVarBlockTile<x_block_tile>());
|
||||
|
||||
return GetBlockNormReduceCrossWarpSync<Problem>()
|
||||
return 2 * GetBlockNormReduceCrossWarpSync<Problem>()
|
||||
.template GetSmemSize<mean_var_block_tile>();
|
||||
}
|
||||
else
|
||||
|
||||
@@ -52,7 +52,20 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
return 2 * Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution()
|
||||
{
|
||||
using S = Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
|
||||
tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>, sequence<1, 2>>,
|
||||
sequence<1, 1>,
|
||||
sequence<0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename XWindow,
|
||||
@@ -117,6 +130,9 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
const auto gamma = load_tile(gamma_window);
|
||||
const auto beta = load_tile(beta_window);
|
||||
|
||||
const auto sm_scale_window =
|
||||
make_tile_window(sm_scale_window_, MakeSmoothInputScaleTileDistribution());
|
||||
|
||||
auto acc = cast_tile<ComputeDataType>(x);
|
||||
|
||||
if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
|
||||
@@ -147,6 +163,9 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
block_norm_reduce(acc, mean, var, cur_count, max_count);
|
||||
block_norm_reduce_sync(mean, var, cur_count);
|
||||
block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
|
||||
|
||||
|
||||
auto sm_scale = load_tile(sm_scale_window);
|
||||
if(kWelford)
|
||||
{
|
||||
block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
|
||||
@@ -189,6 +208,11 @@ struct Layernorm2dFwdPipelineOnePass
|
||||
|
||||
auto ln_ = (acc[idx] - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;
|
||||
ln(idx) = ln_;
|
||||
if constexpr(sm_scale.is_valid())
|
||||
{
|
||||
const auto xs_ = type_convert<ComputeDataType>(sm_scale[j_idx]);
|
||||
ln(idx) = ln(idx) * xs_;
|
||||
}
|
||||
});
|
||||
|
||||
if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT ||
|
||||
|
||||
@@ -46,7 +46,7 @@ struct BlockNormReduce
|
||||
constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
|
||||
|
||||
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
|
||||
if(kWelford)
|
||||
if constexpr(kWelford)
|
||||
{
|
||||
welford_update(mean_tensor(out_dstr_idx),
|
||||
var_tensor(out_dstr_idx),
|
||||
@@ -64,6 +64,42 @@ struct BlockNormReduce
|
||||
});
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_,
|
||||
typename MeanDistributedTensor_,
|
||||
typename VarDistributedTensor_,
|
||||
typename MinDistributedTensor_,
|
||||
typename MaxDistributedTensor_>
|
||||
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
|
||||
MeanDistributedTensor_& mean_tensor,
|
||||
VarDistributedTensor_& var_tensor,
|
||||
MinDistributedTensor_& min_tensor,
|
||||
MaxDistributedTensor_& max_tensor,
|
||||
int& cur_count_, // -> prefer init as zero
|
||||
const int& max_count_)
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
|
||||
constexpr auto spans = XDistributedTensor_::get_distributed_spans();
|
||||
|
||||
sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) {
|
||||
if(cur_count_ < max_count_)
|
||||
{
|
||||
++cur_count_;
|
||||
sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) {
|
||||
constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1);
|
||||
constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);
|
||||
|
||||
auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);
|
||||
mean_tensor(out_dstr_idx) += x;
|
||||
var_tensor(out_dstr_idx) += x * x;
|
||||
min_tensor(out_dstr_idx) = ck_tile::min(x, min_tensor(out_dstr_idx));
|
||||
max_tensor(out_dstr_idx) = ck_tile::max(x, max_tensor(out_dstr_idx));
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE static auto MakeMeanVarBlockTile()
|
||||
{
|
||||
@@ -162,7 +198,7 @@ struct BlockNormReduceSync
|
||||
// pull data from remote lane
|
||||
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
|
||||
const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
|
||||
if(kWelford)
|
||||
if constexpr(kWelford)
|
||||
{
|
||||
const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
|
||||
|
||||
@@ -192,6 +228,95 @@ struct BlockNormReduceSync
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename MeanDistributedTensor_,
|
||||
typename VarDistributedTensor_,
|
||||
typename MinDistributedTensor_,
|
||||
typename MaxDistributedTensor_>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(MeanDistributedTensor_& mean_tensor,
|
||||
VarDistributedTensor_& var_tensor,
|
||||
MinDistributedTensor_& min_tensor,
|
||||
MaxDistributedTensor_& max_tensor,
|
||||
int& count)
|
||||
{
|
||||
using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
// const auto ps_idx = make_array<index_t>(get_warp_id(), get_lane_id());
|
||||
// const auto rs_idx =
|
||||
// mean_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
|
||||
|
||||
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
|
||||
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
|
||||
|
||||
const int original_count = count;
|
||||
|
||||
// loop over thread data
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
auto v_local_mean = mean_tensor.get_thread_buffer()[i];
|
||||
auto v_local_var = var_tensor.get_thread_buffer()[i];
|
||||
auto v_local_min = min_tensor.get_thread_buffer()[i];
|
||||
auto v_local_max = max_tensor.get_thread_buffer()[i];
|
||||
auto v_local_count = original_count;
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
// FIXME: nasty to use does_p_own_r_
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
|
||||
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
// xor
|
||||
index_t src_lane =
|
||||
(__lane_id()) ^
|
||||
(number<lid_over_rid_derivative << istage.value>{}.value);
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
|
||||
const auto v_remote_var = warp_shuffle(v_local_var, src_lane);
|
||||
v_local_mean += v_remote_mean;
|
||||
v_local_var += v_remote_var;
|
||||
const auto v_remote_min = warp_shuffle(v_local_min, src_lane);
|
||||
const auto v_remote_max = warp_shuffle(v_local_max, src_lane);
|
||||
v_local_min = ck_tile::min(v_remote_min, v_local_min);
|
||||
v_local_max = ck_tile::max(v_remote_max, v_local_min);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
mean_tensor.get_thread_buffer()(i) = v_local_mean;
|
||||
var_tensor.get_thread_buffer()(i) = v_local_var;
|
||||
max_tensor.get_thread_buffer()(i) = v_local_min;
|
||||
min_tensor.get_thread_buffer()(i) = v_local_max;
|
||||
if constexpr(kWelford)
|
||||
{
|
||||
count = v_local_count;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -290,7 +415,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
smem_dtype local_scratch_;
|
||||
local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
|
||||
local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
|
||||
if(kWelford)
|
||||
if constexpr(kWelford)
|
||||
{
|
||||
local_scratch_[2] = bit_cast<float>(count);
|
||||
}
|
||||
@@ -326,7 +451,7 @@ struct BlockNormReduceCrossWarpSync
|
||||
const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
||||
const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
|
||||
const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
|
||||
if(kWelford)
|
||||
if constexpr(kWelford)
|
||||
{
|
||||
const auto v_remote_count = bit_cast<int>(v_remote[2]);
|
||||
|
||||
@@ -347,10 +472,102 @@ struct BlockNormReduceCrossWarpSync
|
||||
|
||||
mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
|
||||
var_tensor.get_thread_buffer()(i_0) = v_local_var;
|
||||
if(kWelford)
|
||||
if constexpr(kWelford)
|
||||
count = v_local_count;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename MeanDistributedTensor_,
|
||||
typename VarDistributedTensor_,
|
||||
typename MinDistributedTensor_,
|
||||
typename MaxDistributedTensor_>
|
||||
CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor,
|
||||
VarDistributedTensor_& var_tensor,
|
||||
MinDistributedTensor_& min_tensor,
|
||||
MaxDistributedTensor_& max_tensor,
|
||||
int& count,
|
||||
void* smem)
|
||||
{
|
||||
using DataType = typename MeanDistributedTensor_::DataType;
|
||||
using Dstr = typename MeanDistributedTensor_::StaticTileDistribution;
|
||||
// using DstrEncode = typename Dstr::DstrEncode;
|
||||
// using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
|
||||
static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
|
||||
|
||||
using fused_smem_dtype = fp32x4_t;
|
||||
// Note: we always pack everything into fp32x4
|
||||
fused_smem_dtype* smem_ptr = reinterpret_cast<fused_smem_dtype*>(smem);
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
const index_t smem_offset = warp_id;
|
||||
|
||||
// skip if nonthing to do
|
||||
if constexpr(num_reduce_warps == 1)
|
||||
return;
|
||||
|
||||
// store into smem only for lane-0 within one warp
|
||||
if(lane_id == 0)
|
||||
{
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
fused_smem_dtype local_scratch_;
|
||||
local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
|
||||
local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
|
||||
local_scratch_[2] = bit_cast<float>(min_tensor.get_thread_buffer()[i]);
|
||||
local_scratch_[3] = bit_cast<float>(max_tensor.get_thread_buffer()[i]);
|
||||
smem_ptr[smem_offset + i * num_warps] = local_scratch_;
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// load from smem. here we let everythread to do compute :)
|
||||
index_t local_warp_id = warp_id / num_reduce_warps;
|
||||
index_t local_smem_os = local_warp_id * num_reduce_warps;
|
||||
fused_smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
|
||||
all_scratch[i_0 * num_reduce_warps + i_1] =
|
||||
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
|
||||
});
|
||||
});
|
||||
block_sync_lds(); // TODO: we don't need sync here
|
||||
|
||||
// const int original_count = count;
|
||||
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
// TODO: use descriptor for this
|
||||
auto v_local = all_scratch[i_0 * num_reduce_warps];
|
||||
auto v_local_mean = bit_cast<DataType>(v_local[0]);
|
||||
auto v_local_var = bit_cast<DataType>(v_local[1]);
|
||||
auto v_local_min = bit_cast<DataType>(v_local[2]);
|
||||
auto v_local_max = bit_cast<DataType>(v_local[3]);
|
||||
|
||||
// further reduce mean/var
|
||||
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
|
||||
constexpr auto i_1 = number<i_1_n1 + 1>{};
|
||||
const fused_smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
||||
const auto v_remote_mean = bit_cast<DataType>(v_remote[0]);
|
||||
const auto v_remote_var = bit_cast<DataType>(v_remote[1]);
|
||||
v_local_mean += v_remote_mean;
|
||||
v_local_var += v_remote_var;
|
||||
const auto v_remote_min = bit_cast<DataType>(v_remote[2]);
|
||||
const auto v_remote_max = bit_cast<DataType>(v_remote[3]);
|
||||
v_local_min = ck_tile::min(v_remote_min, v_local_min);
|
||||
v_local_max = ck_tile::max(v_remote_max, v_local_max);
|
||||
});
|
||||
|
||||
mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
|
||||
var_tensor.get_thread_buffer()(i_0) = v_local_var;
|
||||
min_tensor.get_thread_buffer()(i_0) = v_local_min;
|
||||
max_tensor.get_thread_buffer()(i_0) = v_local_max;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// compute the max count for a last dim reduce
|
||||
|
||||
Reference in New Issue
Block a user