mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5031 (commit 1d86a92)
[CK] Replace nested static_for with static_ford to reduce device IR function emissions [1B] (#5031) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary ### Rationale CK's GPU kernels are among the slowest files in the ROCm build, with a single translation unit taking up to 10+ minutes. Profiling with `-ftime-trace` identified nested `static_for` loops as the root cause: each nesting level multiplies the number of unique lambda IR functions the compiler must process. A 2-level nest of `static_for<0, M, 1>` / `static_for<0, N, 1>` produces M×N unique lambda types. With typical GEMM dimensions (M=16, N=4), a single nest generates 64 unique functions — and these nests appear hundreds of times across the codebase. The LLVM backend's CGSCC (Call Graph Strongly Connected Components) framework processes each function independently, so reducing function count directly reduces backend time. ### What changed 393 nested compile-time loop patterns across 73 files are converted to `static_ford`, which flattens multi-dimensional compile-time iteration into a single `static_for` with index decomposition. This eliminates 994 `static_for` nesting levels (42% reduction). Three pattern categories were converted: - **Category A**: `static_for` wrapping `static_ford` — fold outer dimension into ford - **Category B**: nested `static_ford` — merge into single higher-dimensional ford - **Category C**: nested `static_for` chains — convert to single `static_ford` ### Verification **ASM equivalence: PASS — 51/51 device assembly files identical (gfx942 + gfx1100)** | Architecture | Files compared | Largest file | Result | |---|---|---|---| | gfx942 | 36 | 386,685 lines | ALL MATCH | | gfx1100 | 15 | 47,769 lines | ALL MATCH | **Build time (Wilcoxon signed-rank test, 7 paired trials):** | Target | Pre (s) | Post (s) | Delta | p-value | |---|---|---|---|---| | bscale | 169 | 152 | **-9.8%** | 0.016 \* | | xdl_v1234 | 207 | 194 | **-6.6%** | 0.016 \* | | preshuffle | 275 | 264 | **-3.9%** | 0.016 \* | | xdl_base | 142 | 137 | **-3.2%** | 0.031 \* | **IR function counts (device backend, gfx942):** | Target | InstFunc Δ | CodeGen Δ | Compiler Δ | |---|---|---|---| | bscale | -13,043 (-8.2%) | -2,103 (-3.5%) | -10.7% | | xdl_v1234 | -9,431 (-5.7%) | +59 (+0.1%) | -5.2% | | xdl_base | -6,162 (-4.9%) | -1,141 (-2.5%) | -2.2% | | xdl_old | -3,234 (-3.7%) | -963 (-8.7%) | -3.3% | ### Value - **994 fewer `static_for` nesting levels** (-42%) across 73 files - **393 `static_ford` sites** created (from 4 pre-existing) - **Up to 9.8% compile-time reduction** on representative targets (statistically significant, p < 0.05) - **Up to 13K fewer IR function instantiations** per translation unit - Net -849 LOC from reduced indentation - **Zero ASM changes** — identical device code output verified on gfx942 and gfx1100 - All scheduling barriers, `if constexpr` guards, and MFMA/WMMA accumulation order preserved ### Files changed (73) - `block/`: 47 files (GEMM pipelines — xdlops, wmma, moe, preshuffle, blockscale variants) - `grid/`: 20 files (softmax, normalization, reduction, attention, layernorm) - `thread/`: 5 files (tensor slice transfer, contraction, GEMM dlops, reduction) - `tensor_description/`: 1 file (tensor_adaptor) ## Test plan - [x] `static_ford` tested with 21 unit tests in `test/util/unit_ford.cpp` (1D-4D, custom orders, compile-time verification) - [x] All conversions preserve iteration order, `block_sync_lds()` placement, `if constexpr` scheduling guards, and MFMA/WMMA accumulation order - [x] ASM equivalence verified: 51 device `.s` files across gfx942 + gfx1100 - [x] Build-time improvement statistically confirmed (Wilcoxon, p < 0.05, 4 targets) - [x] IR function count reduction confirmed via `-ftime-trace` on 7 targets - [x] Detection script reports 0 remaining safe patterns (180 blocked with structural reasons) - [x] Existing CI tests (GEMM, softmax, normalization, batch norm, reduction, attention) exercise all converted code paths ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
5f90f69795
commit
e5683e2290
@@ -35,14 +35,13 @@ struct ThreadwiseReduction
|
||||
template <typename SrcBufferType, typename DstBufferType>
|
||||
__device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
|
||||
{
|
||||
static_for<0, src_length_m, 1>{}([&](auto iM) {
|
||||
static_ford<Sequence<src_length_m, src_length_k>>{}([&](auto mk) {
|
||||
constexpr auto iM = Number<mk[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<mk[Number<1>{}]>{};
|
||||
constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));
|
||||
constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
static_for<0, src_length_k, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
Accumulation::Calculate(dst_buf(Number<out_offset>{}), src_buf[Number<offset>{}]);
|
||||
});
|
||||
Accumulation::Calculate(dst_buf(Number<out_offset>{}), src_buf[Number<offset>{}]);
|
||||
});
|
||||
};
|
||||
};
|
||||
@@ -81,17 +80,16 @@ struct ThreadwiseReductionWithIndex
|
||||
DstValueBufferType& dst_val_buf,
|
||||
DstIndexBufferType& dst_idx_buf)
|
||||
{
|
||||
static_for<0, src_length_m, 1>{}([&](auto iM) {
|
||||
static_ford<Sequence<src_length_m, src_length_k>>{}([&](auto mk) {
|
||||
constexpr auto iM = Number<mk[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<mk[Number<1>{}]>{};
|
||||
constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));
|
||||
constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
static_for<0, src_length_k, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
Accumulation::Calculate(dst_val_buf(Number<out_offset>{}),
|
||||
src_val_buf[Number<offset>{}],
|
||||
dst_idx_buf(Number<out_offset>{}),
|
||||
src_idx_buf[Number<offset>{}]);
|
||||
});
|
||||
Accumulation::Calculate(dst_val_buf(Number<out_offset>{}),
|
||||
src_val_buf[Number<offset>{}],
|
||||
dst_idx_buf(Number<out_offset>{}),
|
||||
src_idx_buf[Number<offset>{}]);
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
@@ -81,28 +81,21 @@ struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1
|
||||
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
|
||||
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
|
||||
|
||||
static_for<0, TK, 1>{}([&](auto tk) {
|
||||
static_for<0, TM0, 1>{}([&](auto tm0) {
|
||||
static_for<0, TM1, 1>{}([&](auto tm1) {
|
||||
static_for<0, TN0, 1>{}([&](auto tn0) {
|
||||
static_for<0, TN1, 1>{}([&](auto tn1) {
|
||||
constexpr index_t a_offset =
|
||||
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
|
||||
a_origin_idx + make_multi_index(tk, tm0, tm1));
|
||||
constexpr index_t b_offset =
|
||||
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
|
||||
b_origin_idx + make_multi_index(tk, tn0, tn1));
|
||||
constexpr index_t c_offset =
|
||||
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
||||
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
||||
static_ford<Sequence<TK, TM0, TM1, TN0, TN1>>{}([&](auto tkmn) {
|
||||
constexpr auto tk = Number<tkmn[Number<0>{}]>{};
|
||||
constexpr auto tm0 = Number<tkmn[Number<1>{}]>{};
|
||||
constexpr auto tm1 = Number<tkmn[Number<2>{}]>{};
|
||||
constexpr auto tn0 = Number<tkmn[Number<3>{}]>{};
|
||||
constexpr auto tn1 = Number<tkmn[Number<4>{}]>{};
|
||||
constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
|
||||
a_origin_idx + make_multi_index(tk, tm0, tm1));
|
||||
constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
|
||||
b_origin_idx + make_multi_index(tk, tn0, tn1));
|
||||
constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
||||
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
||||
|
||||
inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset>{}],
|
||||
c_buf(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
inner_product<FloatA, FloatB, FloatC>(
|
||||
a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}], c_buf(Number<c_offset>{}));
|
||||
});
|
||||
}
|
||||
};
|
||||
@@ -181,42 +174,35 @@ struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0
|
||||
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
|
||||
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
|
||||
|
||||
static_for<0, TK0, 1>{}([&](auto tk0) {
|
||||
static_for<0, TM0, 1>{}([&](auto tm0) {
|
||||
static_for<0, TM1, 1>{}([&](auto tm1) {
|
||||
static_for<0, TN0, 1>{}([&](auto tn0) {
|
||||
static_for<0, TN1, 1>{}([&](auto tn1) {
|
||||
vector_type<FloatA, TK1> a_vec;
|
||||
vector_type<FloatB, TK1> b_vec;
|
||||
static_ford<Sequence<TK0, TM0, TM1, TN0, TN1>>{}([&](auto tkmn) {
|
||||
constexpr auto tk0 = Number<tkmn[Number<0>{}]>{};
|
||||
constexpr auto tm0 = Number<tkmn[Number<1>{}]>{};
|
||||
constexpr auto tm1 = Number<tkmn[Number<2>{}]>{};
|
||||
constexpr auto tn0 = Number<tkmn[Number<3>{}]>{};
|
||||
constexpr auto tn1 = Number<tkmn[Number<4>{}]>{};
|
||||
vector_type<FloatA, TK1> a_vec;
|
||||
vector_type<FloatB, TK1> b_vec;
|
||||
|
||||
static_for<0, TK1, 1>{}([&](auto tk1) {
|
||||
constexpr index_t a_offset =
|
||||
AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
|
||||
a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1));
|
||||
static_for<0, TK1, 1>{}([&](auto tk1) {
|
||||
constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
|
||||
a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1));
|
||||
|
||||
constexpr index_t b_offset =
|
||||
BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
|
||||
b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1));
|
||||
constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
|
||||
b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1));
|
||||
|
||||
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
|
||||
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
|
||||
});
|
||||
|
||||
using a_vector_t = typename vector_type<FloatA, TK1>::type;
|
||||
using b_vector_t = typename vector_type<FloatB, TK1>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
||||
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
||||
|
||||
inner_product<a_vector_t, b_vector_t, FloatC>(
|
||||
a_vec.template AsType<a_vector_t>()[I0],
|
||||
b_vec.template AsType<b_vector_t>()[I0],
|
||||
c_buf(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
|
||||
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
|
||||
});
|
||||
|
||||
using a_vector_t = typename vector_type<FloatA, TK1>::type;
|
||||
using b_vector_t = typename vector_type<FloatB, TK1>::type;
|
||||
|
||||
constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
|
||||
c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
|
||||
|
||||
inner_product<a_vector_t, b_vector_t, FloatC>(a_vec.template AsType<a_vector_t>()[I0],
|
||||
b_vec.template AsType<b_vector_t>()[I0],
|
||||
c_buf(Number<c_offset>{}));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -81,53 +81,49 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
|
||||
static_for<0, K, 1>{}([&](auto k) {
|
||||
static_for<0, Ho, SubHW>{}([&](auto h) {
|
||||
static_for<0, Wo, SubHW>{}([&](auto w) {
|
||||
static_for<0, E1, 1>{}([&](auto e1) {
|
||||
static_for<0, E2, 1>{}([&](auto e2) {
|
||||
constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
|
||||
a_origin_idx + make_tuple(e1, k, e2));
|
||||
static_ford<Sequence<E1, E2>>{}([&](auto ee) {
|
||||
constexpr auto e1 = Number<ee[Number<0>{}]>{};
|
||||
constexpr auto e2 = Number<ee[Number<1>{}]>{};
|
||||
constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
|
||||
a_origin_idx + make_tuple(e1, k, e2));
|
||||
|
||||
constexpr index_t b0_offset =
|
||||
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
|
||||
b_origin_idx + make_tuple(e1, 0, h, w, e2));
|
||||
constexpr index_t b0_offset =
|
||||
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
|
||||
b_origin_idx + make_tuple(e1, 0, h, w, e2));
|
||||
|
||||
constexpr index_t b1_offset =
|
||||
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
|
||||
b_origin_idx + make_tuple(e1, 0, h, w + 1, e2));
|
||||
constexpr index_t b1_offset =
|
||||
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
|
||||
b_origin_idx + make_tuple(e1, 0, h, w + 1, e2));
|
||||
|
||||
constexpr index_t b2_offset =
|
||||
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
|
||||
b_origin_idx + make_tuple(e1, 0, h + 1, w, e2));
|
||||
constexpr index_t b2_offset =
|
||||
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
|
||||
b_origin_idx + make_tuple(e1, 0, h + 1, w, e2));
|
||||
|
||||
constexpr index_t b3_offset =
|
||||
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
|
||||
b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2));
|
||||
constexpr index_t b3_offset =
|
||||
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
|
||||
b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2));
|
||||
|
||||
constexpr index_t c0_offset =
|
||||
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
|
||||
make_tuple(k, 0, h, w));
|
||||
constexpr index_t c0_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
|
||||
c_origin_idx + make_tuple(k, 0, h, w));
|
||||
|
||||
constexpr index_t c1_offset =
|
||||
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
|
||||
c_origin_idx + make_tuple(k, 0, h, w + 1));
|
||||
constexpr index_t c1_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
|
||||
c_origin_idx + make_tuple(k, 0, h, w + 1));
|
||||
|
||||
constexpr index_t c2_offset =
|
||||
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
|
||||
c_origin_idx + make_tuple(k, 0, h + 1, w));
|
||||
constexpr index_t c2_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
|
||||
c_origin_idx + make_tuple(k, 0, h + 1, w));
|
||||
|
||||
constexpr index_t c3_offset =
|
||||
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
|
||||
c_origin_idx + make_tuple(k, 0, h + 1, w + 1));
|
||||
constexpr index_t c3_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
|
||||
c_origin_idx + make_tuple(k, 0, h + 1, w + 1));
|
||||
|
||||
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b0_offset>{}],
|
||||
b_buf[Number<b1_offset>{}],
|
||||
b_buf[Number<b2_offset>{}],
|
||||
b_buf[Number<b3_offset>{}],
|
||||
c_buf(Number<c0_offset>{}),
|
||||
c_buf(Number<c1_offset>{}),
|
||||
c_buf(Number<c2_offset>{}),
|
||||
c_buf(Number<c3_offset>{}));
|
||||
});
|
||||
amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b0_offset>{}],
|
||||
b_buf[Number<b1_offset>{}],
|
||||
b_buf[Number<b2_offset>{}],
|
||||
b_buf[Number<b3_offset>{}],
|
||||
c_buf(Number<c0_offset>{}),
|
||||
c_buf(Number<c1_offset>{}),
|
||||
c_buf(Number<c2_offset>{}),
|
||||
c_buf(Number<c3_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -136,29 +132,24 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
|
||||
else
|
||||
{
|
||||
|
||||
static_for<0, K, 1>{}([&](auto k) {
|
||||
static_for<0, Ho, 1>{}([&](auto h) {
|
||||
static_for<0, Wo, 1>{}([&](auto w) {
|
||||
static_for<0, E1, 1>{}([&](auto e1) {
|
||||
static_for<0, E2, 1>{}([&](auto e2) {
|
||||
constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
|
||||
a_origin_idx + make_tuple(e1, k, e2));
|
||||
static_ford<Sequence<K, Ho, Wo, E1, E2>>{}([&](auto khwe) {
|
||||
constexpr auto k = Number<khwe[Number<0>{}]>{};
|
||||
constexpr auto h = Number<khwe[Number<1>{}]>{};
|
||||
constexpr auto w = Number<khwe[Number<2>{}]>{};
|
||||
constexpr auto e1 = Number<khwe[Number<3>{}]>{};
|
||||
constexpr auto e2 = Number<khwe[Number<4>{}]>{};
|
||||
constexpr index_t a_offset =
|
||||
AThreadDesc_E1_K_E2{}.CalculateOffset(a_origin_idx + make_tuple(e1, k, e2));
|
||||
|
||||
constexpr index_t b_offset =
|
||||
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
|
||||
b_origin_idx + make_tuple(e1, 0, h, w, e2));
|
||||
constexpr index_t b_offset = BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
|
||||
b_origin_idx + make_tuple(e1, 0, h, w, e2));
|
||||
|
||||
constexpr index_t c_offset =
|
||||
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx +
|
||||
make_tuple(k, 0, h, w));
|
||||
constexpr index_t c_offset =
|
||||
CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w));
|
||||
|
||||
inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset>{}],
|
||||
c_buf(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
|
||||
b_buf[Number<b_offset>{}],
|
||||
c_buf(Number<c_offset>{}));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1827,25 +1827,24 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
static_ford<Sequence<num_access, DstScalarPerVector>>{}([&](auto access_idx) {
|
||||
constexpr auto idx_1d = Number<access_idx[Number<0>{}]>{};
|
||||
constexpr auto i = Number<access_idx[Number<1>{}]>{};
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
// copy data from src_buf into dst_vector
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
DstData v;
|
||||
DstData v;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = v;
|
||||
});
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = v;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -1933,57 +1932,56 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
static_ford<Sequence<num_access, DstScalarPerVector>>{}([&](auto access_idx) {
|
||||
constexpr auto idx_1d = Number<access_idx[Number<0>{}]>{};
|
||||
constexpr auto i = Number<access_idx[Number<1>{}]>{};
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
// copy data from src_buf into dst_vector
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
// src_desc error, non constexpr, caused by merge transform
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
// src_desc error, non constexpr, caused by merge transform
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md +
|
||||
i * dst_scalar_step_in_vector);
|
||||
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(dst_slice_origin_idx + idx_md +
|
||||
i * dst_scalar_step_in_vector);
|
||||
|
||||
SrcData v_this_row, v_theother_row;
|
||||
// int type temp value due to intrinsic requirement
|
||||
int temp = 0;
|
||||
SrcData v_this_row, v_theother_row;
|
||||
// int type temp value due to intrinsic requirement
|
||||
int temp = 0;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
|
||||
// apply element-wise operation
|
||||
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
|
||||
|
||||
// apply intra-row permute.
|
||||
if constexpr(IntraRowSwizzlePerm)
|
||||
{
|
||||
temp = __builtin_amdgcn_permlane16(
|
||||
temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
|
||||
v_this_row = type_convert_sp<SrcData>(temp);
|
||||
}
|
||||
// apply intra-row permute.
|
||||
if constexpr(IntraRowSwizzlePerm)
|
||||
{
|
||||
temp = __builtin_amdgcn_permlane16(
|
||||
temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
|
||||
v_this_row = type_convert_sp<SrcData>(temp);
|
||||
}
|
||||
|
||||
// apply inter-row permute.
|
||||
temp = __builtin_amdgcn_permlanex16(temp,
|
||||
type_convert_sp<int>(v_this_row),
|
||||
LowEightRowlaneIdx,
|
||||
HighEightRowLaneIdx,
|
||||
1,
|
||||
0);
|
||||
v_theother_row = type_convert_sp<SrcData>(temp);
|
||||
// apply inter-row permute.
|
||||
temp = __builtin_amdgcn_permlanex16(temp,
|
||||
type_convert_sp<int>(v_this_row),
|
||||
LowEightRowlaneIdx,
|
||||
HighEightRowLaneIdx,
|
||||
1,
|
||||
0);
|
||||
v_theother_row = type_convert_sp<SrcData>(temp);
|
||||
|
||||
if(get_thread_local_1d_id() % 32 < 16)
|
||||
{
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
|
||||
dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
|
||||
type_convert_sp<DstData>(v_theother_row);
|
||||
}
|
||||
else
|
||||
{
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
|
||||
type_convert_sp<DstData>(v_this_row);
|
||||
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_theother_row);
|
||||
}
|
||||
});
|
||||
if(get_thread_local_1d_id() % 32 < 16)
|
||||
{
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
|
||||
dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
|
||||
type_convert_sp<DstData>(v_theother_row);
|
||||
}
|
||||
else
|
||||
{
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
|
||||
type_convert_sp<DstData>(v_this_row);
|
||||
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_theother_row);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
@@ -2060,36 +2058,35 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
static_ford<Sequence<num_access, DstScalarPerVector>>{}([&](auto access_idx) {
|
||||
constexpr auto idx_1d = Number<access_idx[Number<0>{}]>{};
|
||||
constexpr auto i = Number<access_idx[Number<1>{}]>{};
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
// copy data from src_buf into dst_vector
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
// src_desc error, non constexpr, caused by merge transform
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
// src_desc error, non constexpr, caused by merge transform
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(src_slice_origin_idx + idx_md +
|
||||
i * dst_scalar_step_in_vector);
|
||||
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(dst_slice_origin_idx + idx_md +
|
||||
i * dst_scalar_step_in_vector);
|
||||
|
||||
SrcData v_this_row;
|
||||
// int type temp value due to intrinsic requirement
|
||||
int temp = 0;
|
||||
SrcData v_this_row;
|
||||
// int type temp value due to intrinsic requirement
|
||||
int temp = 0;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
|
||||
// apply element-wise operation
|
||||
element_op_(v_this_row, src_buf[Number<src_offset>{}]);
|
||||
|
||||
// apply intra-row permute.
|
||||
if constexpr(IntraRowSwizzlePerm)
|
||||
{
|
||||
temp = __builtin_amdgcn_permlane16(
|
||||
temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
|
||||
v_this_row = type_convert_sp<SrcData>(temp);
|
||||
}
|
||||
// apply intra-row permute.
|
||||
if constexpr(IntraRowSwizzlePerm)
|
||||
{
|
||||
temp = __builtin_amdgcn_permlane16(
|
||||
temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
|
||||
v_this_row = type_convert_sp<SrcData>(temp);
|
||||
}
|
||||
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
|
||||
});
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
|
||||
});
|
||||
}
|
||||
ElementwiseOperation element_op_{};
|
||||
|
||||
Reference in New Issue
Block a user