mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[rocm-libraries] ROCm/rocm-libraries#5031 (commit 1d86a92)
[CK] Replace nested static_for with static_ford to reduce device IR function emissions [1B] (#5031)

## Summary

### Rationale

CK's GPU kernels are among the slowest files in the ROCm build, with a single translation unit taking 10 minutes or more. Profiling with `-ftime-trace` identified nested `static_for` loops as the root cause: each nesting level multiplies the number of unique lambda IR functions the compiler must process. A 2-level nest of `static_for<0, M, 1>` / `static_for<0, N, 1>` produces M×N unique lambda types. With typical GEMM dimensions (M=16, N=4), a single nest generates 64 unique functions — and these nests appear hundreds of times across the codebase. The LLVM backend's CGSCC (Call Graph Strongly Connected Components) framework processes each function independently, so reducing function count directly reduces backend time.

### What changed

393 nested compile-time loop patterns across 73 files are converted to `static_ford`, which flattens multi-dimensional compile-time iteration into a single `static_for` with index decomposition. This eliminates 994 `static_for` nesting levels (42% reduction).

Three pattern categories were converted:

- **Category A**: `static_for` wrapping `static_ford` — fold outer dimension into ford
- **Category B**: nested `static_ford` — merge into single higher-dimensional ford
- **Category C**: nested `static_for` chains — convert to single `static_ford`

### Verification

**ASM equivalence: PASS — 51/51 device assembly files identical (gfx942 + gfx1100)**

| Architecture | Files compared | Largest file | Result |
|---|---|---|---|
| gfx942 | 36 | 386,685 lines | ALL MATCH |
| gfx1100 | 15 | 47,769 lines | ALL MATCH |

**Build time (Wilcoxon signed-rank test, 7 paired trials):**

| Target | Pre (s) | Post (s) | Delta | p-value |
|---|---|---|---|---|
| bscale | 169 | 152 | **-9.8%** | 0.016 \* |
| xdl_v1234 | 207 | 194 | **-6.6%** | 0.016 \* |
| preshuffle | 275 | 264 | **-3.9%** | 0.016 \* |
| xdl_base | 142 | 137 | **-3.2%** | 0.031 \* |

**IR function counts (device backend, gfx942):**

| Target | InstFunc Δ | CodeGen Δ | Compiler Δ |
|---|---|---|---|
| bscale | -13,043 (-8.2%) | -2,103 (-3.5%) | -10.7% |
| xdl_v1234 | -9,431 (-5.7%) | +59 (+0.1%) | -5.2% |
| xdl_base | -6,162 (-4.9%) | -1,141 (-2.5%) | -2.2% |
| xdl_old | -3,234 (-3.7%) | -963 (-8.7%) | -3.3% |

### Value

- **994 fewer `static_for` nesting levels** (-42%) across 73 files
- **393 `static_ford` sites** created (from 4 pre-existing)
- **Up to 9.8% compile-time reduction** on representative targets (statistically significant, p < 0.05)
- **Up to 13K fewer IR function instantiations** per translation unit
- Net -849 LOC from reduced indentation
- **Zero ASM changes** — identical device code output verified on gfx942 and gfx1100
- All scheduling barriers, `if constexpr` guards, and MFMA/WMMA accumulation order preserved

### Files changed (73)

- `block/`: 47 files (GEMM pipelines — xdlops, wmma, moe, preshuffle, blockscale variants)
- `grid/`: 20 files (softmax, normalization, reduction, attention, layernorm)
- `thread/`: 5 files (tensor slice transfer, contraction, GEMM dlops, reduction)
- `tensor_description/`: 1 file (tensor_adaptor)

## Test plan

- [x] `static_ford` tested with 21 unit tests in `test/util/unit_ford.cpp` (1D-4D, custom orders, compile-time verification)
- [x] All conversions preserve iteration order, `block_sync_lds()` placement, `if constexpr` scheduling guards, and MFMA/WMMA accumulation order
- [x] ASM equivalence verified: 51 device `.s` files across gfx942 + gfx1100
- [x] Build-time improvement statistically confirmed (Wilcoxon, p < 0.05, 4 targets)
- [x] IR function count reduction confirmed via `-ftime-trace` on 7 targets
- [x] Detection script reports 0 remaining safe patterns (180 blocked with structural reasons)
- [x] Existing CI tests (GEMM, softmax, normalization, batch norm, reduction, attention) exercise all converted code paths

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
5f90f69795
commit
e5683e2290
@@ -480,19 +480,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
|
||||
make_tuple(I0, I0),
|
||||
dy_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
|
||||
dy_thread_buf[Number<offset>{}]);
|
||||
dy_elementwise_op(dy_thread_buf(Number<offset>{}), dy_thread_buf[Number<offset>{}]);
|
||||
|
||||
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
|
||||
inv_var_thread_buf[iM];
|
||||
AccDataType norm_x =
|
||||
(x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) * inv_var_thread_buf[iM];
|
||||
|
||||
tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
|
||||
});
|
||||
tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf);
|
||||
|
||||
@@ -502,15 +502,15 @@ struct EpilogueReduceCShuffle
|
||||
[&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
|
||||
|
||||
// reduce in VGPR
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
|
||||
constexpr auto offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
static_ford<Sequence<mreduce_per_thread, nreduce_per_thread>>{}([&](auto ii) {
|
||||
constexpr auto im = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto in = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
|
||||
reduce_in_element_op(c_reduce_thread_buf(offset),
|
||||
c_reduce_thread_buf(offset));
|
||||
});
|
||||
reduce_in_element_op(c_reduce_thread_buf(offset),
|
||||
c_reduce_thread_buf(offset));
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
|
||||
|
||||
@@ -362,11 +362,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
|
||||
make_tuple(I0),
|
||||
gamma_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
|
||||
static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
|
||||
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
|
||||
h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, NThreadSliceSize>>{}([&](auto mn) {
|
||||
constexpr auto m = Number<mn[Number<0>{}]>{};
|
||||
constexpr auto n = Number<mn[Number<1>{}]>{};
|
||||
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
|
||||
h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
|
||||
});
|
||||
|
||||
threadwise_beta_load_n.Run(beta_grid_desc_n,
|
||||
@@ -375,11 +375,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
|
||||
make_tuple(I0),
|
||||
beta_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
|
||||
static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
|
||||
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
|
||||
h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, NThreadSliceSize>>{}([&](auto mn) {
|
||||
constexpr auto m = Number<mn[Number<0>{}]>{};
|
||||
constexpr auto n = Number<mn[Number<1>{}]>{};
|
||||
constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
|
||||
h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
|
||||
});
|
||||
|
||||
threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
|
||||
|
||||
@@ -218,14 +218,12 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, NumReduction, 1>{}([&](auto iR) {
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
|
||||
|
||||
@@ -173,14 +173,12 @@ struct GridwiseMultipleReduction_mk_to_m_threadwise
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, NumReduction, 1>{}([&](auto iR) {
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
|
||||
|
||||
@@ -212,13 +212,11 @@ struct GridwiseReduction_mk_to_m_multiblock
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op(in_thread_buf(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op(in_thread_buf(Number<offset>{}), in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
|
||||
|
||||
@@ -162,13 +162,11 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op(in_thread_buf(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op(in_thread_buf(Number<offset>{}), in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
|
||||
@@ -340,15 +338,13 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
make_tuple(I0, I0),
|
||||
in_thread_idx_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
|
||||
in_thread_val_buf(Number<offset>{}));
|
||||
});
|
||||
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
|
||||
in_thread_val_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
ThreadwiseReduceWithIndex::Reduce(
|
||||
@@ -371,17 +367,15 @@ struct GridwiseReduction_mk_to_m_threadwise
|
||||
make_tuple(I0, I0),
|
||||
in_thread_val_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
|
||||
in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
|
||||
|
||||
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
|
||||
in_thread_val_buf(Number<offset>{}));
|
||||
});
|
||||
in_elementwise_op(in_thread_val_buf(Number<offset>{}),
|
||||
in_thread_val_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
ThreadwiseReduceWithIndex::Reduce(
|
||||
|
||||
@@ -160,13 +160,11 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d
|
||||
make_tuple(I0, I0),
|
||||
in_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op(in_thread_buf(Number<offset>{}),
|
||||
in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
in_elementwise_op(in_thread_buf(Number<offset>{}), in_thread_buf(Number<offset>{}));
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
|
||||
|
||||
@@ -1297,15 +1297,15 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
|
||||
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
|
||||
|
||||
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
|
||||
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
|
||||
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
|
||||
Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
|
||||
Acc1DataType c = c_thread_buf[I]; // O
|
||||
Acc1DataType c_new = c + acc1; // Simply add results since we are no longer using softmax.
|
||||
static_ford<Sequence<c_thread_buf_slice_m, c_thread_buf_slice_n>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iN = Number<ii[Number<1>{}]>{};
|
||||
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
|
||||
Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
|
||||
Acc1DataType c = c_thread_buf[I]; // O
|
||||
Acc1DataType c_new = c + acc1; // Simply add results since we are no longer using softmax.
|
||||
|
||||
c_thread_buf(I) = c_new; // O_new
|
||||
});
|
||||
c_thread_buf(I) = c_new; // O_new
|
||||
});
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
|
||||
|
||||
@@ -1234,19 +1234,19 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
|
||||
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
|
||||
|
||||
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
|
||||
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
|
||||
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
|
||||
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
|
||||
FloatGemmAcc c = c_thread_buf[I]; // O
|
||||
FloatGemmAcc c_new =
|
||||
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
|
||||
math::exp(max[iM] - running_max_new[iM]) * acc1) /
|
||||
running_sum_new[iM]; // Formula by Dao et al.,
|
||||
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
|
||||
static_ford<Sequence<c_thread_buf_slice_m, c_thread_buf_slice_n>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iN = Number<ii[Number<1>{}]>{};
|
||||
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
|
||||
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
|
||||
FloatGemmAcc c = c_thread_buf[I]; // O
|
||||
FloatGemmAcc c_new =
|
||||
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
|
||||
math::exp(max[iM] - running_max_new[iM]) * acc1) /
|
||||
running_sum_new[iM]; // Formula by Dao et al.,
|
||||
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
|
||||
|
||||
c_thread_buf(I) = c_new; // O_new
|
||||
});
|
||||
c_thread_buf(I) = c_new; // O_new
|
||||
});
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
|
||||
|
||||
@@ -1410,18 +1410,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
|
||||
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
|
||||
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
|
||||
|
||||
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
|
||||
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
|
||||
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
|
||||
Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
|
||||
Acc1DataType c = c_thread_buf[I]; // O
|
||||
Acc1DataType c_new =
|
||||
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
|
||||
math::exp(max[iM] - running_max_new[iM]) * acc1) /
|
||||
running_sum_new[iM];
|
||||
static_ford<Sequence<c_thread_buf_slice_m, c_thread_buf_slice_n>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iN = Number<ii[Number<1>{}]>{};
|
||||
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
|
||||
Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
|
||||
Acc1DataType c = c_thread_buf[I]; // O
|
||||
Acc1DataType c_new =
|
||||
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
|
||||
math::exp(max[iM] - running_max_new[iM]) * acc1) /
|
||||
running_sum_new[iM];
|
||||
|
||||
c_thread_buf(I) = c_new; // O_new
|
||||
});
|
||||
c_thread_buf(I) = c_new; // O_new
|
||||
});
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
|
||||
|
||||
@@ -1048,19 +1048,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
|
||||
constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
|
||||
|
||||
static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
|
||||
static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
|
||||
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
|
||||
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
|
||||
FloatGemmAcc c = c_thread_buf[I]; // O
|
||||
FloatGemmAcc c_new =
|
||||
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
|
||||
math::exp(max[iM] - running_max_new[iM]) * acc1) /
|
||||
running_sum_new[iM]; // Formula by Dao et al.,
|
||||
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
|
||||
static_ford<Sequence<c_thread_buf_slice_m, c_thread_buf_slice_n>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iN = Number<ii[Number<1>{}]>{};
|
||||
auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
|
||||
FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
|
||||
FloatGemmAcc c = c_thread_buf[I]; // O
|
||||
FloatGemmAcc c_new =
|
||||
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
|
||||
math::exp(max[iM] - running_max_new[iM]) * acc1) /
|
||||
running_sum_new[iM]; // Formula by Dao et al.,
|
||||
// https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
|
||||
|
||||
c_thread_buf(I) = c_new; // O_new
|
||||
});
|
||||
c_thread_buf(I) = c_new; // O_new
|
||||
});
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
|
||||
|
||||
@@ -437,19 +437,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
|
||||
make_tuple(I0, I0),
|
||||
dy_thread_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
|
||||
dy_elementwise_op(dy_thread_buf(Number<offset>{}),
|
||||
dy_thread_buf[Number<offset>{}]);
|
||||
dy_elementwise_op(dy_thread_buf(Number<offset>{}), dy_thread_buf[Number<offset>{}]);
|
||||
|
||||
AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
|
||||
inv_var_thread_buf[iM];
|
||||
AccDataType norm_x =
|
||||
(x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) * inv_var_thread_buf[iM];
|
||||
|
||||
tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
|
||||
});
|
||||
tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf);
|
||||
|
||||
@@ -355,8 +355,10 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
static_ford<Sequence<MThreadSliceSize, XSrcVectorSize>>{}(
|
||||
[&](auto mk) { // input add loop
|
||||
constexpr auto iM = Number<mk[Number<0>{}]>{};
|
||||
constexpr auto iK1 = Number<mk[Number<1>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
@@ -376,7 +378,6 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
|
||||
|
||||
unpack2(x_elementwise_op, out_data_refs, in_data_refs);
|
||||
});
|
||||
});
|
||||
threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf);
|
||||
|
||||
if constexpr(!SweepOnce)
|
||||
@@ -435,21 +436,20 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
|
||||
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<XThreadBufferNumber, XSrcVectorSize>>{}([&](auto ii) {
|
||||
constexpr auto iK0 = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK1 = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
divisor;
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * divisor;
|
||||
|
||||
// gamma
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// gamma
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -463,19 +463,19 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<MThreadSliceSize, XThreadBufferNumber, XSrcVectorSize>>{}(
|
||||
[&](auto mii) {
|
||||
constexpr auto iM = Number<mii[Number<0>{}]>{};
|
||||
constexpr auto iK0 = Number<mii[Number<1>{}]>{};
|
||||
constexpr auto iK1 = Number<mii[Number<2>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_y_store.Run(thread_buffer_desc_m_k,
|
||||
|
||||
@@ -937,8 +937,10 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
[&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
|
||||
|
||||
// reduce in VGPR
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
|
||||
static_ford<Sequence<mreduce_per_thread, nreduce_per_thread>>{}(
|
||||
[&](auto ii) {
|
||||
constexpr auto im = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto in = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
@@ -946,7 +948,6 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
reduce_in_element_op(c_reduce_thread_buf(offset),
|
||||
c_reduce_thread_buf(offset));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
|
||||
|
||||
|
||||
@@ -881,14 +881,14 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
ThreadReduceOperation::template GetIdentityValue<FloatReduceAcc>();
|
||||
static_for<0, mreduce_per_thread, 1>{}(
|
||||
[&](auto I) { r_thread_buf(I) = reduce_identityVal; });
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
|
||||
constexpr auto offset =
|
||||
Number<cde_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
static_ford<Sequence<mreduce_per_thread, nreduce_per_thread>>{}([&](auto ii) {
|
||||
constexpr auto im = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto in = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset =
|
||||
Number<cde_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
|
||||
qs_element_op[Ir](e_thread_buf(offset), e_thread_buf(offset));
|
||||
});
|
||||
qs_element_op[Ir](e_thread_buf(offset), e_thread_buf(offset));
|
||||
});
|
||||
ThreadwiseReduce::Reduce(e_thread_buf, r_thread_buf);
|
||||
|
||||
|
||||
@@ -820,8 +820,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
[&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
|
||||
|
||||
// reduce in VGPR
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
|
||||
static_ford<Sequence<mreduce_per_thread, nreduce_per_thread>>{}(
|
||||
[&](auto ii) {
|
||||
constexpr auto im = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto in = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
@@ -829,7 +831,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
reduce_in_element_op(c_reduce_thread_buf(offset),
|
||||
c_reduce_thread_buf(offset));
|
||||
});
|
||||
});
|
||||
|
||||
ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
|
||||
|
||||
|
||||
@@ -997,26 +997,26 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0]
|
||||
.GetUpperLengths()[I1]; // TODO: proper handle
|
||||
|
||||
static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
|
||||
static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
|
||||
constexpr auto dst_offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
static_ford<Sequence<mreduce_per_thread, nreduce_per_thread>>{}([&](auto ii) {
|
||||
constexpr auto im = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto in = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto dst_offset =
|
||||
Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
|
||||
make_tuple(im, in))>{};
|
||||
|
||||
constexpr auto src_offset =
|
||||
Number<d_reduce_thread_desc_mperblock.CalculateOffset(
|
||||
make_tuple(im))>{};
|
||||
constexpr auto src_offset =
|
||||
Number<d_reduce_thread_desc_mperblock.CalculateOffset(
|
||||
make_tuple(im))>{};
|
||||
|
||||
FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw;
|
||||
FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw;
|
||||
FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw;
|
||||
FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw;
|
||||
|
||||
FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum;
|
||||
FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
|
||||
FloatReduceAcc divisor_sqrt;
|
||||
tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor);
|
||||
FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum;
|
||||
FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
|
||||
FloatReduceAcc divisor_sqrt;
|
||||
tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor);
|
||||
|
||||
c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt;
|
||||
});
|
||||
c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt;
|
||||
});
|
||||
|
||||
// scaling
|
||||
|
||||
@@ -290,12 +290,12 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
}
|
||||
|
||||
// do element-wise pre-reduction operation
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
out_thread_buf(Number<offset>{}) =
|
||||
math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
out_thread_buf(Number<offset>{}) =
|
||||
math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
|
||||
});
|
||||
|
||||
ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);
|
||||
@@ -330,15 +330,13 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
in_thread_buf);
|
||||
}
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
out_thread_buf(Number<offset>{}) =
|
||||
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
|
||||
accu_value_buf(iM);
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
out_thread_buf(Number<offset>{}) =
|
||||
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
|
||||
accu_value_buf(iM);
|
||||
});
|
||||
|
||||
threadwise_dst_store.Run(thread_buffer_desc,
|
||||
@@ -376,16 +374,14 @@ struct GridwiseSoftmax_mk_to_mk
|
||||
make_tuple(I0, I0),
|
||||
in_prior_dst_buf);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
// out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
|
||||
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset =
|
||||
thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
out_thread_buf(Number<offset>{}) =
|
||||
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
|
||||
accu_value_buf(iM) +
|
||||
beta * in_prior_dst_buf(Number<offset>{});
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
|
||||
out_thread_buf(Number<offset>{}) =
|
||||
alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
|
||||
accu_value_buf(iM) +
|
||||
beta * in_prior_dst_buf(Number<offset>{});
|
||||
});
|
||||
|
||||
threadwise_dst_store.Run(thread_buffer_desc,
|
||||
|
||||
@@ -175,36 +175,36 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
|
||||
};
|
||||
|
||||
auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
auto in_data_refs = generate_tie(
|
||||
[&](auto i_embedding_) -> const auto& {
|
||||
return in_thread_bufs(i_embedding_)(Number<register_offset>{});
|
||||
},
|
||||
Number<NumEmbeddings>{});
|
||||
auto out_data_refs = generate_tie(
|
||||
[&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
|
||||
Number<1>{});
|
||||
unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
|
||||
});
|
||||
static_ford<Sequence<DimThreadSize, RowVectorSize>>{}([&](auto ii) {
|
||||
constexpr auto i_dim_vec_ = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto i_row_vec_ = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
auto in_data_refs = generate_tie(
|
||||
[&](auto i_embedding_) -> const auto& {
|
||||
return in_thread_bufs(i_embedding_)(Number<register_offset>{});
|
||||
},
|
||||
Number<NumEmbeddings>{});
|
||||
auto out_data_refs = generate_tie(
|
||||
[&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
|
||||
Number<1>{});
|
||||
unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
|
||||
});
|
||||
};
|
||||
|
||||
auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
constexpr auto mean_var_offset =
|
||||
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
|
||||
static_ford<Sequence<DimThreadSize, RowVectorSize>>{}([&](auto ii) {
|
||||
constexpr auto i_dim_vec_ = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto i_row_vec_ = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
constexpr auto mean_var_offset =
|
||||
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
|
||||
|
||||
threadwise_welford.cur_count_++;
|
||||
threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
|
||||
var_thread_buf(Number<mean_var_offset>{}),
|
||||
acc_thread_buf(Number<register_offset>{}));
|
||||
});
|
||||
threadwise_welford.cur_count_++;
|
||||
threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
|
||||
var_thread_buf(Number<mean_var_offset>{}),
|
||||
acc_thread_buf(Number<register_offset>{}));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -246,12 +246,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
|
||||
};
|
||||
|
||||
// first load index
|
||||
ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
|
||||
ck::static_ford<Sequence<DimPerBlock, NumEmbeddings>>{}([&](auto ie) {
|
||||
constexpr auto i_idx_ = Number<ie[Number<0>{}]>{};
|
||||
constexpr auto i_embedding_ = Number<ie[Number<1>{}]>{};
|
||||
// prefer use s_load
|
||||
ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
|
||||
index_bufs(i_embedding_)(i_idx_) =
|
||||
p_indexes[i_embedding_][index_start + i_idx_.value];
|
||||
});
|
||||
index_bufs(i_embedding_)(i_idx_) = p_indexes[i_embedding_][index_start + i_idx_.value];
|
||||
});
|
||||
|
||||
// load gamma/beta
|
||||
|
||||
@@ -176,36 +176,36 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
|
||||
};
|
||||
|
||||
auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
auto in_data_refs = generate_tie(
|
||||
[&](auto i_embedding_) -> const auto& {
|
||||
return in_thread_bufs(i_embedding_)(Number<register_offset>{});
|
||||
},
|
||||
Number<NumEmbeddings>{});
|
||||
auto out_data_refs = generate_tie(
|
||||
[&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
|
||||
Number<1>{});
|
||||
unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
|
||||
});
|
||||
static_ford<Sequence<DimThreadSize, RowVectorSize>>{}([&](auto ii) {
|
||||
constexpr auto i_dim_vec_ = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto i_row_vec_ = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
auto in_data_refs = generate_tie(
|
||||
[&](auto i_embedding_) -> const auto& {
|
||||
return in_thread_bufs(i_embedding_)(Number<register_offset>{});
|
||||
},
|
||||
Number<NumEmbeddings>{});
|
||||
auto out_data_refs = generate_tie(
|
||||
[&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
|
||||
Number<1>{});
|
||||
unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
|
||||
});
|
||||
};
|
||||
|
||||
auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
|
||||
static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
|
||||
static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
constexpr auto mean_var_offset =
|
||||
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
|
||||
static_ford<Sequence<DimThreadSize, RowVectorSize>>{}([&](auto ii) {
|
||||
constexpr auto i_dim_vec_ = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto i_row_vec_ = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto register_offset = thread_buf_desc.CalculateOffset(
|
||||
make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
|
||||
constexpr auto mean_var_offset =
|
||||
mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
|
||||
|
||||
threadwise_welford.cur_count_++;
|
||||
threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
|
||||
var_thread_buf(Number<mean_var_offset>{}),
|
||||
acc_thread_buf(Number<register_offset>{}));
|
||||
});
|
||||
threadwise_welford.cur_count_++;
|
||||
threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
|
||||
var_thread_buf(Number<mean_var_offset>{}),
|
||||
acc_thread_buf(Number<register_offset>{}));
|
||||
});
|
||||
};
|
||||
|
||||
@@ -247,12 +247,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
|
||||
};
|
||||
|
||||
// first load index
|
||||
ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
|
||||
ck::static_ford<Sequence<DimPerBlock, NumEmbeddings>>{}([&](auto ie) {
|
||||
constexpr auto i_idx_ = Number<ie[Number<0>{}]>{};
|
||||
constexpr auto i_embedding_ = Number<ie[Number<1>{}]>{};
|
||||
// prefer use s_load
|
||||
ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
|
||||
index_bufs(i_embedding_)(i_idx_) =
|
||||
p_indexes[i_embedding_][index_start + i_idx_.value];
|
||||
});
|
||||
index_bufs(i_embedding_)(i_idx_) = p_indexes[i_embedding_][index_start + i_idx_.value];
|
||||
});
|
||||
|
||||
// load gamma/beta
|
||||
|
||||
@@ -328,14 +328,14 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
make_tuple(I0, I0),
|
||||
gamma_thread_buf(i));
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
x_square_thread_buf(i)(Number<offset_m_k>{}) =
|
||||
x_thread_buf(i)(Number<offset_m_k>{}) *
|
||||
x_thread_buf(i)(Number<offset_m_k>{});
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, XSrcVectorSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
x_square_thread_buf(i)(Number<offset_m_k>{}) =
|
||||
x_thread_buf(i)(Number<offset_m_k>{}) *
|
||||
x_thread_buf(i)(Number<offset_m_k>{});
|
||||
});
|
||||
|
||||
ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
|
||||
@@ -391,24 +391,24 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
}
|
||||
|
||||
// normalization
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
|
||||
[&](auto idx) {
|
||||
constexpr auto iM = Number<idx[Number<0>{}]>{};
|
||||
constexpr auto iK0 = Number<idx[Number<1>{}]>{};
|
||||
constexpr auto iK1 = Number<idx[Number<2>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
inv_std_thread_buf(iM);
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
inv_std_thread_buf(iM);
|
||||
|
||||
// gamma & beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// gamma & beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_beta_load.Run(beta_grid_desc_m_k,
|
||||
@@ -422,19 +422,19 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
|
||||
[&](auto idx) {
|
||||
constexpr auto iM = Number<idx[Number<0>{}]>{};
|
||||
constexpr auto iK0 = Number<idx[Number<1>{}]>{};
|
||||
constexpr auto iK1 = Number<idx[Number<2>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_y_store.Run(thread_buffer_desc_m_k,
|
||||
@@ -460,14 +460,14 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
x_thread_buf(i));
|
||||
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
x_square_thread_buf(i)(Number<offset_m_k>{}) =
|
||||
x_thread_buf(i)(Number<offset_m_k>{}) *
|
||||
x_thread_buf(i)(Number<offset_m_k>{});
|
||||
});
|
||||
static_ford<Sequence<MThreadSliceSize, XSrcVectorSize>>{}([&](auto ii) {
|
||||
constexpr auto iM = Number<ii[Number<0>{}]>{};
|
||||
constexpr auto iK = Number<ii[Number<1>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
|
||||
x_square_thread_buf(i)(Number<offset_m_k>{}) =
|
||||
x_thread_buf(i)(Number<offset_m_k>{}) *
|
||||
x_thread_buf(i)(Number<offset_m_k>{});
|
||||
});
|
||||
|
||||
ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
|
||||
@@ -544,24 +544,24 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
|
||||
[&](auto idx) {
|
||||
constexpr auto iM = Number<idx[Number<0>{}]>{};
|
||||
constexpr auto iK0 = Number<idx[Number<1>{}]>{};
|
||||
constexpr auto iK1 = Number<idx[Number<2>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
inv_std_thread_buf(iM);
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
inv_std_thread_buf(iM);
|
||||
|
||||
// gamma
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// gamma
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_beta_load.Run(beta_grid_desc_m_k,
|
||||
@@ -573,19 +573,19 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
|
||||
[&](auto idx) {
|
||||
constexpr auto iM = Number<idx[Number<0>{}]>{};
|
||||
constexpr auto iK0 = Number<idx[Number<1>{}]>{};
|
||||
constexpr auto iK1 = Number<idx[Number<2>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_y_store.Run(thread_buffer_desc_m_k,
|
||||
|
||||
@@ -441,24 +441,24 @@ struct GridwiseNormalizationSplitK2nd
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
|
||||
[&](auto idx) {
|
||||
constexpr auto iM = Number<idx[Number<0>{}]>{};
|
||||
constexpr auto iK0 = Number<idx[Number<1>{}]>{};
|
||||
constexpr auto iK1 = Number<idx[Number<2>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
inv_std_thread_buf(iM);
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
inv_std_thread_buf(iM);
|
||||
|
||||
// gamma
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// gamma
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_beta_load.Run(beta_grid_desc_m_k,
|
||||
@@ -470,19 +470,19 @@ struct GridwiseNormalizationSplitK2nd
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
|
||||
[&](auto idx) {
|
||||
constexpr auto iM = Number<idx[Number<0>{}]>{};
|
||||
constexpr auto iK0 = Number<idx[Number<1>{}]>{};
|
||||
constexpr auto iK1 = Number<idx[Number<2>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_y_store.Run(thread_buffer_desc_m_k,
|
||||
|
||||
@@ -365,24 +365,24 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
|
||||
}
|
||||
|
||||
// normalization
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
|
||||
[&](auto idx) {
|
||||
constexpr auto iM = Number<idx[Number<0>{}]>{};
|
||||
constexpr auto iK0 = Number<idx[Number<1>{}]>{};
|
||||
constexpr auto iK1 = Number<idx[Number<2>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
inv_std_thread_buf(iM);
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
inv_std_thread_buf(iM);
|
||||
|
||||
// gamma & beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// gamma & beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_beta_load.Run(beta_grid_desc_m_k,
|
||||
@@ -396,19 +396,19 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
|
||||
[&](auto idx) {
|
||||
constexpr auto iM = Number<idx[Number<0>{}]>{};
|
||||
constexpr auto iK0 = Number<idx[Number<1>{}]>{};
|
||||
constexpr auto iK1 = Number<idx[Number<2>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_y_store.Run(thread_buffer_desc_m_k,
|
||||
@@ -496,24 +496,24 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
|
||||
[&](auto idx) {
|
||||
constexpr auto iM = Number<idx[Number<0>{}]>{};
|
||||
constexpr auto iK0 = Number<idx[Number<1>{}]>{};
|
||||
constexpr auto iK1 = Number<idx[Number<2>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
inv_std_thread_buf(iM);
|
||||
// normalize
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
(x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
|
||||
inv_std_thread_buf(iM);
|
||||
|
||||
// gamma
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// gamma
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) *
|
||||
gamma_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_beta_load.Run(beta_grid_desc_m_k,
|
||||
@@ -525,19 +525,19 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
|
||||
thread_copy_fwd_step_m_k);
|
||||
});
|
||||
|
||||
static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
|
||||
static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
|
||||
[&](auto idx) {
|
||||
constexpr auto iM = Number<idx[Number<0>{}]>{};
|
||||
constexpr auto iK0 = Number<idx[Number<1>{}]>{};
|
||||
constexpr auto iK1 = Number<idx[Number<2>{}]>{};
|
||||
constexpr auto offset_m_k =
|
||||
thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
|
||||
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
// beta
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) =
|
||||
y_thread_buf(iK0)(Number<offset_m_k>{}) +
|
||||
beta_thread_buf(iK0)(Number<offset_m_k>{});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
|
||||
threadwise_y_store.Run(thread_buffer_desc_m_k,
|
||||
|
||||
Reference in New Issue
Block a user