[rocm-libraries] ROCm/rocm-libraries#5031 (commit 1d86a92)

[CK] Replace nested static_for with static_ford to reduce device IR function emissions [1B] (#5031)

## Summary

### Rationale
CK's GPU kernels are among the slowest files in the ROCm build, with a single
translation unit taking more than 10 minutes. Profiling with `-ftime-trace`
identified nested `static_for` loops as the root cause: each nesting level
multiplies the number of unique lambda IR functions the compiler must process.
A 2-level nest of `static_for<0, M, 1>` / `static_for<0, N, 1>` produces M×N
unique lambda types. With typical GEMM dimensions (M=16, N=4), a single nest
generates 64 unique functions, and these nests appear hundreds of times across
the codebase.
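
To make the blow-up concrete, here is a minimal, self-contained sketch
(C++20). `static_for_sketch` is a stand-in written for this description, not
CK's actual `static_for`, but the mechanism is the same: each instantiation of
the outer generic lambda stamps out a fresh inner closure type, so the inner
call operator is instantiated once per (iM, iN) pair.

```cpp
// Sketch only: a stand-in for CK's static_for<Begin, End, 1>, showing why
// nesting multiplies lambda instantiations.
#include <cstddef>
#include <utility>

template <std::size_t Begin, std::size_t End, typename F>
constexpr void static_for_sketch(F&& f)
{
    [&]<std::size_t... Is>(std::index_sequence<Is...>) {
        (f(std::integral_constant<std::size_t, Begin + Is>{}), ...);
    }(std::make_index_sequence<End - Begin>{});
}

int main()
{
    constexpr std::size_t M = 16, N = 4; // typical GEMM thread-slice dims
    int acc = 0;
    // Each of the M instantiations of the outer lambda defines a distinct
    // inner closure type, and each inner closure is instantiated N times:
    // M * N = 64 unique call operators reach the LLVM backend.
    static_for_sketch<0, M>([&](auto iM) {
        static_for_sketch<0, N>([&](auto iN) {
            acc += int(iM) + int(iN);
        });
    });
    return acc % 2;
}
```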

The LLVM backend's CGSCC (Call Graph Strongly Connected Components)
framework processes
each function independently, so reducing function count directly reduces
backend time.

### What changed
393 nested compile-time loop patterns across 73 files were converted to
`static_ford`, which flattens multi-dimensional compile-time iteration into a
single `static_for` with index decomposition (see the sketch after the
category list below). This eliminates 994 `static_for` nesting levels (42%
reduction).

Three pattern categories were converted:
- **Category A**: `static_for` wrapping `static_ford` — fold outer
dimension into ford
- **Category B**: nested `static_ford` — merge into single
higher-dimensional ford
- **Category C**: nested `static_for` chains — convert to single
`static_ford`

### Verification

**ASM equivalence: PASS — 51/51 device assembly files identical (gfx942
+ gfx1100)**

| Architecture | Files compared | Largest file | Result |
|---|---|---|---|
| gfx942 | 36 | 386,685 lines | ALL MATCH |
| gfx1100 | 15 | 47,769 lines | ALL MATCH |

**Build time (Wilcoxon signed-rank test, 7 paired trials):**

| Target | Pre (s) | Post (s) | Delta | p-value |
|---|---|---|---|---|
| bscale | 169 | 152 | **-9.8%** | 0.016 \* |
| xdl_v1234 | 207 | 194 | **-6.6%** | 0.016 \* |
| preshuffle | 275 | 264 | **-3.9%** | 0.016 \* |
| xdl_base | 142 | 137 | **-3.2%** | 0.031 \* |

\* statistically significant at p < 0.05

**IR function counts (device backend, gfx942):**

| Target | InstFunc Δ | CodeGen Δ | Compiler Δ |
|---|---|---|---|
| bscale | -13,043 (-8.2%) | -2,103 (-3.5%) | -10.7% |
| xdl_v1234 | -9,431 (-5.7%) | +59 (+0.1%) | -5.2% |
| xdl_base | -6,162 (-4.9%) | -1,141 (-2.5%) | -2.2% |
| xdl_old | -3,234 (-3.7%) | -963 (-8.7%) | -3.3% |

### Value
- **994 fewer `static_for` nesting levels** (-42%) across 73 files
- **393 new `static_ford` sites** (only 4 existed before this change)
- **Up to 9.8% compile-time reduction** on representative targets
(statistically significant, p < 0.05)
- **Up to 13K fewer IR function instantiations** per translation unit
- Net -849 LOC from reduced indentation
- **Zero ASM changes** — identical device code output verified on gfx942
and gfx1100
- All scheduling barriers, `if constexpr` guards, and MFMA/WMMA
accumulation order preserved

### Files changed (73)
- `block/`: 47 files (GEMM pipelines — xdlops, wmma, moe, preshuffle,
blockscale variants)
- `grid/`: 20 files (softmax, normalization, reduction, attention,
layernorm)
- `thread/`: 5 files (tensor slice transfer, contraction, GEMM dlops,
reduction)
- `tensor_description/`: 1 file (tensor_adaptor)

## Test plan
- [x] `static_ford` tested with 21 unit tests in
`test/util/unit_ford.cpp`
  (1D-4D, custom orders, compile-time verification)
- [x] All conversions preserve iteration order, `block_sync_lds()` placement,
  `if constexpr` scheduling guards, and MFMA/WMMA accumulation order (a
  compile-time sketch of the order check follows this list)
- [x] ASM equivalence verified: 51 device `.s` files across gfx942 +
gfx1100
- [x] Build-time improvement statistically confirmed (Wilcoxon, p <
0.05, 4 targets)
- [x] IR function count reduction confirmed via `-ftime-trace` on 7
targets
- [x] Detection script reports 0 remaining safely convertible patterns
  (180 sites intentionally skipped, each with a recorded structural reason)
- [x] Existing CI tests (GEMM, softmax, normalization, batch norm,
reduction,
  attention) exercise all converted code paths
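
For reference, a minimal compile-time sketch of the iteration-order check, in
the spirit of the unit tests above (illustrative C++20 only, not the actual
`test/util/unit_ford.cpp`):

```cpp
// Verify at compile time that row-major flat decomposition visits (iM, iK)
// in exactly the order of the original two-level nest.
#include <array>
#include <cstddef>
#include <utility>

template <std::size_t M, std::size_t K>
constexpr auto nest_visits()
{
    std::array<std::pair<std::size_t, std::size_t>, M * K> v{};
    std::size_t n = 0;
    for(std::size_t iM = 0; iM < M; ++iM)
        for(std::size_t iK = 0; iK < K; ++iK)
            v[n++] = {iM, iK}; // order of the nested static_for
    return v;
}

template <std::size_t M, std::size_t K>
constexpr auto ford_visits()
{
    std::array<std::pair<std::size_t, std::size_t>, M * K> v{};
    for(std::size_t flat = 0; flat < M * K; ++flat)
        v[flat] = {flat / K, flat % K}; // order of the flattened loop
    return v;
}

static_assert(nest_visits<16, 4>() == ford_visits<16, 4>(),
              "flattened iteration must match the nested order");

int main() { return 0; }
```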

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Author: Christopher Millette
Date: 2026-03-18 14:46:50 +00:00
Committed by: assistant-librarian[bot]
Parent: 5f90f69795
Commit: e5683e2290

73 changed files with 9398 additions and 10247 deletions

```diff
@@ -480,19 +480,17 @@ struct GridwiseWelfordSecondHalfReduceFirstHalf
                                      make_tuple(I0, I0),
                                      dy_thread_buf);
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    dy_elementwise_op(dy_thread_buf(Number<offset>{}),
-                                      dy_thread_buf[Number<offset>{}]);
-                    AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
-                                         inv_var_thread_buf[iM];
-                    tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
-                });
+            static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iK = Number<ii[Number<1>{}]>{};
+                constexpr auto offset = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                dy_elementwise_op(dy_thread_buf(Number<offset>{}), dy_thread_buf[Number<offset>{}]);
+                AccDataType norm_x =
+                    (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) * inv_var_thread_buf[iM];
+                tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
             });
             ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf);
```

```diff
@@ -502,15 +502,15 @@ struct EpilogueReduceCShuffle
             [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
         // reduce in VGPR
-        static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
-            static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
-                constexpr auto offset =
-                    Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
-                        make_tuple(im, in))>{};
-                reduce_in_element_op(c_reduce_thread_buf(offset),
-                                     c_reduce_thread_buf(offset));
-            });
+        static_ford<Sequence<mreduce_per_thread, nreduce_per_thread>>{}([&](auto ii) {
+            constexpr auto im = Number<ii[Number<0>{}]>{};
+            constexpr auto in = Number<ii[Number<1>{}]>{};
+            constexpr auto offset =
+                Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
+                    make_tuple(im, in))>{};
+            reduce_in_element_op(c_reduce_thread_buf(offset),
+                                 c_reduce_thread_buf(offset));
         });
         ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
```

```diff
@@ -362,11 +362,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                    make_tuple(I0),
                                    gamma_thread_buf);
-        static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
-            static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
-                constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
-                h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
-            });
+        static_ford<Sequence<MThreadSliceSize, NThreadSliceSize>>{}([&](auto mn) {
+            constexpr auto m = Number<mn[Number<0>{}]>{};
+            constexpr auto n = Number<mn[Number<1>{}]>{};
+            constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
+            h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
         });
         threadwise_beta_load_n.Run(beta_grid_desc_n,
@@ -375,11 +375,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                    make_tuple(I0),
                                    beta_thread_buf);
-        static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
-            static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
-                constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
-                h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
-            });
+        static_ford<Sequence<MThreadSliceSize, NThreadSliceSize>>{}([&](auto mn) {
+            constexpr auto m = Number<mn[Number<0>{}]>{};
+            constexpr auto n = Number<mn[Number<1>{}]>{};
+            constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
+            h_thread_buf(Number<m_n>{}) = h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
         });
         threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
```

```diff
@@ -218,14 +218,12 @@ struct GridwiseMultipleReduction_mk_to_m_multiblock
                                   in_thread_buf);
         static_for<0, NumReduction, 1>{}([&](auto iR) {
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                // do element-wise pre-reduction operation
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset =
-                        thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                    in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
-                                                in_thread_buf(Number<offset>{}));
-                });
+            static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iK = Number<ii[Number<1>{}]>{};
+                constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
+                                            in_thread_buf(Number<offset>{}));
             });
             ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
```

```diff
@@ -173,14 +173,12 @@ struct GridwiseMultipleReduction_mk_to_m_threadwise
                                   in_thread_buf);
         static_for<0, NumReduction, 1>{}([&](auto iR) {
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                // do element-wise pre-reduction operation
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset =
-                        thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                    in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
-                                                in_thread_buf(Number<offset>{}));
-                });
+            static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iK = Number<ii[Number<1>{}]>{};
+                constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number<offset>{}),
+                                            in_thread_buf(Number<offset>{}));
             });
             ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR));
```

```diff
@@ -212,13 +212,11 @@ struct GridwiseReduction_mk_to_m_multiblock
                                   make_tuple(I0, I0),
                                   in_thread_buf);
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            // do element-wise pre-reduction operation
-            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                in_elementwise_op(in_thread_buf(Number<offset>{}),
-                                  in_thread_buf(Number<offset>{}));
-            });
+        static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+            constexpr auto iM = Number<ii[Number<0>{}]>{};
+            constexpr auto iK = Number<ii[Number<1>{}]>{};
+            constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+            in_elementwise_op(in_thread_buf(Number<offset>{}), in_thread_buf(Number<offset>{}));
         });
         ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
```

```diff
@@ -162,13 +162,11 @@ struct GridwiseReduction_mk_to_m_threadwise
                                   make_tuple(I0, I0),
                                   in_thread_buf);
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            // do element-wise pre-reduction operation
-            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                in_elementwise_op(in_thread_buf(Number<offset>{}),
-                                  in_thread_buf(Number<offset>{}));
-            });
+        static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+            constexpr auto iM = Number<ii[Number<0>{}]>{};
+            constexpr auto iK = Number<ii[Number<1>{}]>{};
+            constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+            in_elementwise_op(in_thread_buf(Number<offset>{}), in_thread_buf(Number<offset>{}));
         });
         ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
@@ -340,15 +338,13 @@ struct GridwiseReduction_mk_to_m_threadwise
                                       make_tuple(I0, I0),
                                       in_thread_idx_buf);
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            // do element-wise pre-reduction operation
-            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                constexpr auto offset =
-                    thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                in_elementwise_op(in_thread_val_buf(Number<offset>{}),
-                                  in_thread_val_buf(Number<offset>{}));
-            });
+        static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+            constexpr auto iM = Number<ii[Number<0>{}]>{};
+            constexpr auto iK = Number<ii[Number<1>{}]>{};
+            constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+            in_elementwise_op(in_thread_val_buf(Number<offset>{}),
+                              in_thread_val_buf(Number<offset>{}));
         });
         ThreadwiseReduceWithIndex::Reduce(
@@ -371,17 +367,15 @@ struct GridwiseReduction_mk_to_m_threadwise
                                       make_tuple(I0, I0),
                                       in_thread_val_buf);
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            // do element-wise pre-reduction operation
-            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                constexpr auto offset =
-                    thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
-                in_elementwise_op(in_thread_val_buf(Number<offset>{}),
-                                  in_thread_val_buf(Number<offset>{}));
-            });
+        static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+            constexpr auto iM = Number<ii[Number<0>{}]>{};
+            constexpr auto iK = Number<ii[Number<1>{}]>{};
+            constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+            in_thread_idx_buf(Number<offset>{}) = indexStart + iK();
+            in_elementwise_op(in_thread_val_buf(Number<offset>{}),
+                              in_thread_val_buf(Number<offset>{}));
         });
         ThreadwiseReduceWithIndex::Reduce(
```

```diff
@@ -160,13 +160,11 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d
                                   make_tuple(I0, I0),
                                   in_thread_buf);
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            // do element-wise pre-reduction operation
-            static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                in_elementwise_op(in_thread_buf(Number<offset>{}),
-                                  in_thread_buf(Number<offset>{}));
-            });
+        static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+            constexpr auto iM = Number<ii[Number<0>{}]>{};
+            constexpr auto iK = Number<ii[Number<1>{}]>{};
+            constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+            in_elementwise_op(in_thread_buf(Number<offset>{}), in_thread_buf(Number<offset>{}));
         });
         ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);
```

```diff
@@ -1297,15 +1297,15 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
             constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
             constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
-            static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
-                static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
-                    auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
-                    Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
-                    Acc1DataType c = c_thread_buf[I];       // O
-                    Acc1DataType c_new = c + acc1; // Simply add results since we are no longer using softmax.
-                    c_thread_buf(I) = c_new; // O_new
-                });
+            static_ford<Sequence<c_thread_buf_slice_m, c_thread_buf_slice_n>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iN = Number<ii[Number<1>{}]>{};
+                auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
+                Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
+                Acc1DataType c = c_thread_buf[I];       // O
+                Acc1DataType c_new = c + acc1; // Simply add results since we are no longer using softmax.
+                c_thread_buf(I) = c_new; // O_new
             });
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
```

```diff
@@ -1234,19 +1234,19 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
             constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
             constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
-            static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
-                static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
-                    auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
-                    FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
-                    FloatGemmAcc c = c_thread_buf[I];       // O
-                    FloatGemmAcc c_new =
-                        (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
-                         math::exp(max[iM] - running_max_new[iM]) * acc1) /
-                        running_sum_new[iM]; // Formula by Dao et al.,
-                    // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
-                    c_thread_buf(I) = c_new; // O_new
-                });
+            static_ford<Sequence<c_thread_buf_slice_m, c_thread_buf_slice_n>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iN = Number<ii[Number<1>{}]>{};
+                auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
+                FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
+                FloatGemmAcc c = c_thread_buf[I];       // O
+                FloatGemmAcc c_new =
+                    (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
+                     math::exp(max[iM] - running_max_new[iM]) * acc1) /
+                    running_sum_new[iM]; // Formula by Dao et al.,
+                // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
+                c_thread_buf(I) = c_new; // O_new
             });
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
```

```diff
@@ -1410,18 +1410,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
             constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
-            static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
-                static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
-                    auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
-                    Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
-                    Acc1DataType c = c_thread_buf[I];       // O
-                    Acc1DataType c_new =
-                        (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
-                         math::exp(max[iM] - running_max_new[iM]) * acc1) /
-                        running_sum_new[iM];
-                    c_thread_buf(I) = c_new; // O_new
-                });
+            static_ford<Sequence<c_thread_buf_slice_m, c_thread_buf_slice_n>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iN = Number<ii[Number<1>{}]>{};
+                auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
+                Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
+                Acc1DataType c = c_thread_buf[I];       // O
+                Acc1DataType c_new =
+                    (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
+                     math::exp(max[iM] - running_max_new[iM]) * acc1) /
+                    running_sum_new[iM];
+                c_thread_buf(I) = c_new; // O_new
             });
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
```

```diff
@@ -1048,19 +1048,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
             constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
-            static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
-                static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
-                    auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
-                    FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
-                    FloatGemmAcc c = c_thread_buf[I];       // O
-                    FloatGemmAcc c_new =
-                        (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
-                         math::exp(max[iM] - running_max_new[iM]) * acc1) /
-                        running_sum_new[iM]; // Formula by Dao et al.,
-                    // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
-                    c_thread_buf(I) = c_new; // O_new
-                });
+            static_ford<Sequence<c_thread_buf_slice_m, c_thread_buf_slice_n>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iN = Number<ii[Number<1>{}]>{};
+                auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
+                FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
+                FloatGemmAcc c = c_thread_buf[I];       // O
+                FloatGemmAcc c_new =
+                    (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
+                     math::exp(max[iM] - running_max_new[iM]) * acc1) /
+                    running_sum_new[iM]; // Formula by Dao et al.,
+                // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
+                c_thread_buf(I) = c_new; // O_new
             });
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
```

```diff
@@ -437,19 +437,17 @@ struct GridwiseBatchNormBackwardWithBlockwiseWelford
                                          make_tuple(I0, I0),
                                          dy_thread_buf);
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    dy_elementwise_op(dy_thread_buf(Number<offset>{}),
-                                      dy_thread_buf[Number<offset>{}]);
-                    AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
-                                         inv_var_thread_buf[iM];
-                    tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
-                });
+            static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iK = Number<ii[Number<1>{}]>{};
+                constexpr auto offset = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                dy_elementwise_op(dy_thread_buf(Number<offset>{}), dy_thread_buf[Number<offset>{}]);
+                AccDataType norm_x =
+                    (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) * inv_var_thread_buf[iM];
+                tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
             });
             ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf);
```

```diff
@@ -355,8 +355,10 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
                                                  thread_copy_fwd_step_m_k);
         });
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop
-            static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
+        static_ford<Sequence<MThreadSliceSize, XSrcVectorSize>>{}(
+            [&](auto mk) { // input add loop
+                constexpr auto iM = Number<mk[Number<0>{}]>{};
+                constexpr auto iK1 = Number<mk[Number<1>{}]>{};
                 constexpr auto offset_m_k =
                     thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
@@ -376,7 +378,6 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
                 unpack2(x_elementwise_op, out_data_refs, in_data_refs);
             });
-        });
         threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf);
         if constexpr(!SweepOnce)
@@ -435,21 +436,20 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
         static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
             auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
-            static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // normalize
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-                        divisor;
-                    // gamma
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) *
-                        gamma_thread_buf(iK0)(Number<offset_m_k>{});
-                });
+            static_ford<Sequence<XThreadBufferNumber, XSrcVectorSize>>{}([&](auto ii) {
+                constexpr auto iK0 = Number<ii[Number<0>{}]>{};
+                constexpr auto iK1 = Number<ii[Number<1>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // normalize
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) * divisor;
+                // gamma
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                    gamma_thread_buf(iK0)(Number<offset_m_k>{});
             });
         });
@@ -463,19 +463,19 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
                                                  thread_copy_fwd_step_m_k);
         });
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // beta
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) +
-                        beta_thread_buf(iK0)(Number<offset_m_k>{});
-                });
-            });
+        static_ford<Sequence<MThreadSliceSize, XThreadBufferNumber, XSrcVectorSize>>{}(
+            [&](auto mii) {
+                constexpr auto iM = Number<mii[Number<0>{}]>{};
+                constexpr auto iK0 = Number<mii[Number<1>{}]>{};
+                constexpr auto iK1 = Number<mii[Number<2>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // beta
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) +
+                    beta_thread_buf(iK0)(Number<offset_m_k>{});
             });
         static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
             threadwise_y_store.Run(thread_buffer_desc_m_k,
```

```diff
@@ -937,8 +937,10 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
             // reduce in VGPR
-            static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
-                static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
+            static_ford<Sequence<mreduce_per_thread, nreduce_per_thread>>{}(
+                [&](auto ii) {
+                    constexpr auto im = Number<ii[Number<0>{}]>{};
+                    constexpr auto in = Number<ii[Number<1>{}]>{};
                     constexpr auto offset =
                         Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
                             make_tuple(im, in))>{};
@@ -946,7 +948,6 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                     reduce_in_element_op(c_reduce_thread_buf(offset),
                                          c_reduce_thread_buf(offset));
                 });
-            });
             ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
```

```diff
@@ -881,14 +881,14 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 ThreadReduceOperation::template GetIdentityValue<FloatReduceAcc>();
             static_for<0, mreduce_per_thread, 1>{}(
                 [&](auto I) { r_thread_buf(I) = reduce_identityVal; });
-            static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
-                static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
-                    constexpr auto offset =
-                        Number<cde_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
-                            make_tuple(im, in))>{};
-                    qs_element_op[Ir](e_thread_buf(offset), e_thread_buf(offset));
-                });
+            static_ford<Sequence<mreduce_per_thread, nreduce_per_thread>>{}([&](auto ii) {
+                constexpr auto im = Number<ii[Number<0>{}]>{};
+                constexpr auto in = Number<ii[Number<1>{}]>{};
+                constexpr auto offset =
+                    Number<cde_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
+                        make_tuple(im, in))>{};
+                qs_element_op[Ir](e_thread_buf(offset), e_thread_buf(offset));
             });
             ThreadwiseReduce::Reduce(e_thread_buf, r_thread_buf);
```

```diff
@@ -820,8 +820,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; });
             // reduce in VGPR
-            static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
-                static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
+            static_ford<Sequence<mreduce_per_thread, nreduce_per_thread>>{}(
+                [&](auto ii) {
+                    constexpr auto im = Number<ii[Number<0>{}]>{};
+                    constexpr auto in = Number<ii[Number<1>{}]>{};
                     constexpr auto offset =
                         Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
                             make_tuple(im, in))>{};
@@ -829,7 +831,6 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                     reduce_in_element_op(c_reduce_thread_buf(offset),
                                          c_reduce_thread_buf(offset));
                 });
-            });
             ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf);
```

```diff
@@ -997,26 +997,26 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0]
                     .GetUpperLengths()[I1]; // TODO: proper handle
-            static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
-                static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
-                    constexpr auto dst_offset =
-                        Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
-                            make_tuple(im, in))>{};
-                    constexpr auto src_offset =
-                        Number<d_reduce_thread_desc_mperblock.CalculateOffset(
-                            make_tuple(im))>{};
-                    FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw;
-                    FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw;
-                    FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum;
-                    FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
-                    FloatReduceAcc divisor_sqrt;
-                    tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor);
-                    c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt;
-                });
+            static_ford<Sequence<mreduce_per_thread, nreduce_per_thread>>{}([&](auto ii) {
+                constexpr auto im = Number<ii[Number<0>{}]>{};
+                constexpr auto in = Number<ii[Number<1>{}]>{};
+                constexpr auto dst_offset =
+                    Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
+                        make_tuple(im, in))>{};
+                constexpr auto src_offset =
+                    Number<d_reduce_thread_desc_mperblock.CalculateOffset(
+                        make_tuple(im))>{};
+                FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw;
+                FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw;
+                FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum;
+                FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
+                FloatReduceAcc divisor_sqrt;
+                tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor);
+                c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt;
             });
             // scaling
```

```diff
@@ -290,12 +290,12 @@ struct GridwiseSoftmax_mk_to_mk
             }
             // do element-wise pre-reduction operation
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                    out_thread_buf(Number<offset>{}) =
-                        math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
-                });
+            static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iK = Number<ii[Number<1>{}]>{};
+                constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                out_thread_buf(Number<offset>{}) =
+                    math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
             });
             ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);
@@ -330,15 +330,13 @@ struct GridwiseSoftmax_mk_to_mk
                                         in_thread_buf);
             }
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                // out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset =
-                        thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                    out_thread_buf(Number<offset>{}) =
-                        alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
-                        accu_value_buf(iM);
-                });
+            static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iK = Number<ii[Number<1>{}]>{};
+                constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                out_thread_buf(Number<offset>{}) =
+                    alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
+                    accu_value_buf(iM);
             });
             threadwise_dst_store.Run(thread_buffer_desc,
@@ -376,16 +374,14 @@ struct GridwiseSoftmax_mk_to_mk
                                    make_tuple(I0, I0),
                                    in_prior_dst_buf);
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
-                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
-                    constexpr auto offset =
-                        thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
-                    out_thread_buf(Number<offset>{}) =
-                        alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
-                            accu_value_buf(iM) +
-                        beta * in_prior_dst_buf(Number<offset>{});
-                });
+            static_ford<Sequence<MThreadSliceSize, KThreadSliceSize>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iK = Number<ii[Number<1>{}]>{};
+                constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
+                out_thread_buf(Number<offset>{}) =
+                    alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
+                        accu_value_buf(iM) +
+                    beta * in_prior_dst_buf(Number<offset>{});
             });
             threadwise_dst_store.Run(thread_buffer_desc,
```

```diff
@@ -175,36 +175,36 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
         };
         auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
-            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
-                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
-                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
-                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
-                    auto in_data_refs = generate_tie(
-                        [&](auto i_embedding_) -> const auto& {
-                            return in_thread_bufs(i_embedding_)(Number<register_offset>{});
-                        },
-                        Number<NumEmbeddings>{});
-                    auto out_data_refs = generate_tie(
-                        [&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
-                        Number<1>{});
-                    unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
-                });
+            static_ford<Sequence<DimThreadSize, RowVectorSize>>{}([&](auto ii) {
+                constexpr auto i_dim_vec_ = Number<ii[Number<0>{}]>{};
+                constexpr auto i_row_vec_ = Number<ii[Number<1>{}]>{};
+                constexpr auto register_offset = thread_buf_desc.CalculateOffset(
+                    make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
+                auto in_data_refs = generate_tie(
+                    [&](auto i_embedding_) -> const auto& {
+                        return in_thread_bufs(i_embedding_)(Number<register_offset>{});
+                    },
+                    Number<NumEmbeddings>{});
+                auto out_data_refs = generate_tie(
+                    [&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
+                    Number<1>{});
+                unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
             });
         };
         auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
-            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
-                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
-                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
-                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
-                    constexpr auto mean_var_offset =
-                        mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
-                    threadwise_welford.cur_count_++;
-                    threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
-                                              var_thread_buf(Number<mean_var_offset>{}),
-                                              acc_thread_buf(Number<register_offset>{}));
-                });
+            static_ford<Sequence<DimThreadSize, RowVectorSize>>{}([&](auto ii) {
+                constexpr auto i_dim_vec_ = Number<ii[Number<0>{}]>{};
+                constexpr auto i_row_vec_ = Number<ii[Number<1>{}]>{};
+                constexpr auto register_offset = thread_buf_desc.CalculateOffset(
+                    make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
+                constexpr auto mean_var_offset =
+                    mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
+                threadwise_welford.cur_count_++;
+                threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
+                                          var_thread_buf(Number<mean_var_offset>{}),
+                                          acc_thread_buf(Number<register_offset>{}));
             });
         };
@@ -246,12 +246,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
         };
         // first load index
-        ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
-            // prefer use s_load
-            ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
-                index_bufs(i_embedding_)(i_idx_) =
-                    p_indexes[i_embedding_][index_start + i_idx_.value];
-            });
+        ck::static_ford<Sequence<DimPerBlock, NumEmbeddings>>{}([&](auto ie) {
+            constexpr auto i_idx_ = Number<ie[Number<0>{}]>{};
+            constexpr auto i_embedding_ = Number<ie[Number<1>{}]>{};
+            // prefer use s_load
+            index_bufs(i_embedding_)(i_idx_) = p_indexes[i_embedding_][index_start + i_idx_.value];
         });
         // load gamma/beta
```

```diff
@@ -176,36 +176,36 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
         };
         auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
-            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
-                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
-                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
-                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
-                    auto in_data_refs = generate_tie(
-                        [&](auto i_embedding_) -> const auto& {
-                            return in_thread_bufs(i_embedding_)(Number<register_offset>{});
-                        },
-                        Number<NumEmbeddings>{});
-                    auto out_data_refs = generate_tie(
-                        [&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
-                        Number<1>{});
-                    unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
-                });
+            static_ford<Sequence<DimThreadSize, RowVectorSize>>{}([&](auto ii) {
+                constexpr auto i_dim_vec_ = Number<ii[Number<0>{}]>{};
+                constexpr auto i_row_vec_ = Number<ii[Number<1>{}]>{};
+                constexpr auto register_offset = thread_buf_desc.CalculateOffset(
+                    make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
+                auto in_data_refs = generate_tie(
+                    [&](auto i_embedding_) -> const auto& {
+                        return in_thread_bufs(i_embedding_)(Number<register_offset>{});
+                    },
+                    Number<NumEmbeddings>{});
+                auto out_data_refs = generate_tie(
+                    [&](auto) -> auto& { return acc_thread_buf(Number<register_offset>{}); },
+                    Number<1>{});
+                unpack2(emb_elementwise_op, out_data_refs, in_data_refs);
             });
         };
         auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) {
-            static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) {
-                static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) {
-                    constexpr auto register_offset = thread_buf_desc.CalculateOffset(
-                        make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
-                    constexpr auto mean_var_offset =
-                        mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
-                    threadwise_welford.cur_count_++;
-                    threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
-                                              var_thread_buf(Number<mean_var_offset>{}),
-                                              acc_thread_buf(Number<register_offset>{}));
-                });
+            static_ford<Sequence<DimThreadSize, RowVectorSize>>{}([&](auto ii) {
+                constexpr auto i_dim_vec_ = Number<ii[Number<0>{}]>{};
+                constexpr auto i_row_vec_ = Number<ii[Number<1>{}]>{};
+                constexpr auto register_offset = thread_buf_desc.CalculateOffset(
+                    make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_));
+                constexpr auto mean_var_offset =
+                    mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_));
+                threadwise_welford.cur_count_++;
+                threadwise_welford.Update(mean_thread_buf(Number<mean_var_offset>{}),
+                                          var_thread_buf(Number<mean_var_offset>{}),
+                                          acc_thread_buf(Number<register_offset>{}));
             });
         };
@@ -247,12 +247,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
         };
         // first load index
-        ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) {
-            // prefer use s_load
-            ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) {
-                index_bufs(i_embedding_)(i_idx_) =
-                    p_indexes[i_embedding_][index_start + i_idx_.value];
-            });
+        ck::static_ford<Sequence<DimPerBlock, NumEmbeddings>>{}([&](auto ie) {
+            constexpr auto i_idx_ = Number<ie[Number<0>{}]>{};
+            constexpr auto i_embedding_ = Number<ie[Number<1>{}]>{};
+            // prefer use s_load
+            index_bufs(i_embedding_)(i_idx_) = p_indexes[i_embedding_][index_start + i_idx_.value];
        });
         // load gamma/beta
```

```diff
@@ -328,14 +328,14 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                                      make_tuple(I0, I0),
                                      gamma_thread_buf(i));
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    x_square_thread_buf(i)(Number<offset_m_k>{}) =
-                        x_thread_buf(i)(Number<offset_m_k>{}) *
-                        x_thread_buf(i)(Number<offset_m_k>{});
-                });
+            static_ford<Sequence<MThreadSliceSize, XSrcVectorSize>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iK = Number<ii[Number<1>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                x_square_thread_buf(i)(Number<offset_m_k>{}) =
+                    x_thread_buf(i)(Number<offset_m_k>{}) *
+                    x_thread_buf(i)(Number<offset_m_k>{});
             });
             ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
@@ -391,24 +391,24 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
         }
         // normalization
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // normalize
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-                        inv_std_thread_buf(iM);
-                    // gamma & beta
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) *
-                        gamma_thread_buf(iK0)(Number<offset_m_k>{});
-                });
-            });
+        static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
+            [&](auto idx) {
+                constexpr auto iM = Number<idx[Number<0>{}]>{};
+                constexpr auto iK0 = Number<idx[Number<1>{}]>{};
+                constexpr auto iK1 = Number<idx[Number<2>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // normalize
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
+                    inv_std_thread_buf(iM);
+                // gamma & beta
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                    gamma_thread_buf(iK0)(Number<offset_m_k>{});
             });
         static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
             threadwise_beta_load.Run(beta_grid_desc_m_k,
@@ -422,19 +422,19 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                                      thread_copy_fwd_step_m_k);
         });
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // beta
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) +
-                        beta_thread_buf(iK0)(Number<offset_m_k>{});
-                });
-            });
+        static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
+            [&](auto idx) {
+                constexpr auto iM = Number<idx[Number<0>{}]>{};
+                constexpr auto iK0 = Number<idx[Number<1>{}]>{};
+                constexpr auto iK1 = Number<idx[Number<2>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // beta
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) +
+                    beta_thread_buf(iK0)(Number<offset_m_k>{});
             });
         static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
             threadwise_y_store.Run(thread_buffer_desc_m_k,
@@ -460,14 +460,14 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                                      x_thread_buf(i));
             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
-            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    x_square_thread_buf(i)(Number<offset_m_k>{}) =
-                        x_thread_buf(i)(Number<offset_m_k>{}) *
-                        x_thread_buf(i)(Number<offset_m_k>{});
-                });
+            static_ford<Sequence<MThreadSliceSize, XSrcVectorSize>>{}([&](auto ii) {
+                constexpr auto iM = Number<ii[Number<0>{}]>{};
+                constexpr auto iK = Number<ii[Number<1>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
+                x_square_thread_buf(i)(Number<offset_m_k>{}) =
+                    x_thread_buf(i)(Number<offset_m_k>{}) *
+                    x_thread_buf(i)(Number<offset_m_k>{});
             });
             ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf);
@@ -544,24 +544,24 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                                      thread_copy_fwd_step_m_k);
         });
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // normalize
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-                        inv_std_thread_buf(iM);
-                    // gamma
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) *
-                        gamma_thread_buf(iK0)(Number<offset_m_k>{});
-                });
-            });
+        static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
+            [&](auto idx) {
+                constexpr auto iM = Number<idx[Number<0>{}]>{};
+                constexpr auto iK0 = Number<idx[Number<1>{}]>{};
+                constexpr auto iK1 = Number<idx[Number<2>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // normalize
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
+                    inv_std_thread_buf(iM);
+                // gamma
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                    gamma_thread_buf(iK0)(Number<offset_m_k>{});
             });
         static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
             threadwise_beta_load.Run(beta_grid_desc_m_k,
@@ -573,19 +573,19 @@ struct GridwiseNormalizationNaiveVariance_mk_to_mk
                                      thread_copy_fwd_step_m_k);
         });
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // beta
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) +
-                        beta_thread_buf(iK0)(Number<offset_m_k>{});
-                });
-            });
+        static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
+            [&](auto idx) {
+                constexpr auto iM = Number<idx[Number<0>{}]>{};
+                constexpr auto iK0 = Number<idx[Number<1>{}]>{};
+                constexpr auto iK1 = Number<idx[Number<2>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // beta
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) +
+                    beta_thread_buf(iK0)(Number<offset_m_k>{});
             });
         static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
             threadwise_y_store.Run(thread_buffer_desc_m_k,
```

```diff
@@ -441,24 +441,24 @@ struct GridwiseNormalizationSplitK2nd
                                      thread_copy_fwd_step_m_k);
         });
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // normalize
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-                        inv_std_thread_buf(iM);
-                    // gamma
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) *
-                        gamma_thread_buf(iK0)(Number<offset_m_k>{});
-                });
-            });
+        static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
+            [&](auto idx) {
+                constexpr auto iM = Number<idx[Number<0>{}]>{};
+                constexpr auto iK0 = Number<idx[Number<1>{}]>{};
+                constexpr auto iK1 = Number<idx[Number<2>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // normalize
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
+                    inv_std_thread_buf(iM);
+                // gamma
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                    gamma_thread_buf(iK0)(Number<offset_m_k>{});
             });
         static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
             threadwise_beta_load.Run(beta_grid_desc_m_k,
@@ -470,19 +470,19 @@ struct GridwiseNormalizationSplitK2nd
                                      thread_copy_fwd_step_m_k);
         });
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // beta
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) +
-                        beta_thread_buf(iK0)(Number<offset_m_k>{});
-                });
-            });
+        static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
+            [&](auto idx) {
+                constexpr auto iM = Number<idx[Number<0>{}]>{};
+                constexpr auto iK0 = Number<idx[Number<1>{}]>{};
+                constexpr auto iK1 = Number<idx[Number<2>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // beta
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) +
+                    beta_thread_buf(iK0)(Number<offset_m_k>{});
             });
         static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
             threadwise_y_store.Run(thread_buffer_desc_m_k,
```

```diff
@@ -365,24 +365,24 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
         }
         // normalization
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // normalize
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-                        inv_std_thread_buf(iM);
-                    // gamma & beta
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) *
-                        gamma_thread_buf(iK0)(Number<offset_m_k>{});
-                });
-            });
+        static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
+            [&](auto idx) {
+                constexpr auto iM = Number<idx[Number<0>{}]>{};
+                constexpr auto iK0 = Number<idx[Number<1>{}]>{};
+                constexpr auto iK1 = Number<idx[Number<2>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // normalize
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
+                    inv_std_thread_buf(iM);
+                // gamma & beta
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                    gamma_thread_buf(iK0)(Number<offset_m_k>{});
             });
         static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
             threadwise_beta_load.Run(beta_grid_desc_m_k,
@@ -396,19 +396,19 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                                      thread_copy_fwd_step_m_k);
         });
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // beta
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) +
-                        beta_thread_buf(iK0)(Number<offset_m_k>{});
-                });
-            });
+        static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
+            [&](auto idx) {
+                constexpr auto iM = Number<idx[Number<0>{}]>{};
+                constexpr auto iK0 = Number<idx[Number<1>{}]>{};
+                constexpr auto iK1 = Number<idx[Number<2>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // beta
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) +
+                    beta_thread_buf(iK0)(Number<offset_m_k>{});
             });
         static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
             threadwise_y_store.Run(thread_buffer_desc_m_k,
@@ -496,24 +496,24 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                                      thread_copy_fwd_step_m_k);
         });
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // normalize
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
-                        inv_std_thread_buf(iM);
-                    // gamma
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) *
-                        gamma_thread_buf(iK0)(Number<offset_m_k>{});
-                });
-            });
+        static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
+            [&](auto idx) {
+                constexpr auto iM = Number<idx[Number<0>{}]>{};
+                constexpr auto iK0 = Number<idx[Number<1>{}]>{};
+                constexpr auto iK1 = Number<idx[Number<2>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // normalize
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
+                    inv_std_thread_buf(iM);
+                // gamma
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) *
+                    gamma_thread_buf(iK0)(Number<offset_m_k>{});
             });
         static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
             threadwise_beta_load.Run(beta_grid_desc_m_k,
@@ -525,19 +525,19 @@ struct GridwiseNormalizationWelfordVariance_mk_to_mk
                                      thread_copy_fwd_step_m_k);
         });
-        static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
-            static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) {
-                static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
-                    constexpr auto offset_m_k =
-                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
-                    // beta
-                    y_thread_buf(iK0)(Number<offset_m_k>{}) =
-                        y_thread_buf(iK0)(Number<offset_m_k>{}) +
-                        beta_thread_buf(iK0)(Number<offset_m_k>{});
-                });
-            });
+        static_ford<Sequence<MThreadSliceSize, ThreadBufferNumber, XSrcVectorSize>>{}(
+            [&](auto idx) {
+                constexpr auto iM = Number<idx[Number<0>{}]>{};
+                constexpr auto iK0 = Number<idx[Number<1>{}]>{};
+                constexpr auto iK1 = Number<idx[Number<2>{}]>{};
+                constexpr auto offset_m_k =
+                    thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));
+                // beta
+                y_thread_buf(iK0)(Number<offset_m_k>{}) =
+                    y_thread_buf(iK0)(Number<offset_m_k>{}) +
+                    beta_thread_buf(iK0)(Number<offset_m_k>{});
             });
         static_for<0, ThreadBufferNumber, 1>{}([&](auto i) {
             threadwise_y_store.Run(thread_buffer_desc_m_k,
```