mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
fix the clang-format (#2578)
This commit is contained in:
@@ -415,12 +415,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t N2_m = kNPerBlock / (N1_m * N0);
|
||||
constexpr auto dstr_m = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<N0, N1_m, N2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // N0, N1 K1
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<2, 1, 2>, // K0 N2 K2
|
||||
sequence<0, 2, 2>>{});
|
||||
sequence<>,
|
||||
tuple<sequence<N0, N1_m, N2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // N0, N1 K1
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<2, 1, 2>, // K0 N2 K2
|
||||
sequence<0, 2, 2>>{});
|
||||
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kNPerBlock * kKPerBlock);
|
||||
return dstr_m;
|
||||
@@ -464,12 +464,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
|
||||
constexpr auto dstr_m = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<N0_m, N1, N2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<2, 1, 2>, // K0 N0 K2
|
||||
sequence<0, 0, 2>>{});
|
||||
sequence<>,
|
||||
tuple<sequence<N0_m, N1, N2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<2, 1, 2>, // K0 N0 K2
|
||||
sequence<0, 0, 2>>{});
|
||||
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kNPerBlock * kKPerBlock);
|
||||
return dstr_m;
|
||||
@@ -515,12 +515,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
|
||||
constexpr auto dstr_m = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<M0, M1_m, M2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<2, 1, 2>, // K0 M2 K2
|
||||
sequence<0, 2, 2>>{});
|
||||
sequence<>,
|
||||
tuple<sequence<M0, M1_m, M2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<2, 1, 2>, // K0 M2 K2
|
||||
sequence<0, 2, 2>>{});
|
||||
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock);
|
||||
return dstr_m;
|
||||
@@ -566,12 +566,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
|
||||
constexpr auto dstr_m = make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<M0, M1_m, M2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<2, 1, 2>, // K0 M2 K2
|
||||
sequence<0, 2, 2>>{});
|
||||
sequence<>,
|
||||
tuple<sequence<M0, M1_m, M2_m>, sequence<K0_m, K1_m, K2>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
|
||||
tuple<sequence<0>, sequence<1, 1>>,
|
||||
sequence<2, 1, 2>, // K0 M2 K2
|
||||
sequence<0, 2, 2>>{});
|
||||
static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
|
||||
kMPerBlock * kKPerBlock);
|
||||
return dstr_m;
|
||||
|
||||
Reference in New Issue
Block a user