mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
[CK_TILE] Flatten nested static_for loops into static_ford (#5939)
## Summary
Mechanical conversion of 129 nested `static_for`/`static_ford` patterns
to flat `static_ford` across 29 ck_tile header files.
Each conversion eliminates intermediate lambda closure instantiations by
replacing nested compile-time loops with a single flat iteration using
index decomposition.
### What `static_ford` eliminates
When `static_for` loops are nested, each level creates unique closure
types:
```cpp
// BEFORE: M + M×N = 20 IR functions (for M=4, N=4)
static_for<0, 4, 1>{}([&](auto m) { // 4 closure instantiations
static_for<0, 4, 1>{}([&](auto n) { // 4×4 = 16 closure instantiations
body(m, n);
});
});
// AFTER: M×N = 16 IR functions (with ford_applier, no intermediates)
static_ford<sequence<4, 4>>{}([&](auto mn) {
constexpr auto m = number<mn[number<0>{}]>{};
constexpr auto n = number<mn[number<1>{}]>{};
body(m, n);
});
```
### Pattern categories converted
| Category | Count | Description |
|----------|-------|-------------|
| C (2-level `static_for` chains) | 112 | Nested `static_for` →
`static_ford` |
| C3 (3-level `static_for` chains) | 9 | Three consecutive nests →
`static_ford` |
| Partial rescue | 3 | Outer 2 levels of blocked 4-level nests |
| B (nested `static_ford` merge) | 5 | Two nested `static_ford` → single
higher-dim `static_ford` |
| **Total** | **129** | Across 29 files |
6 false positives were detected and reverted (in `tensor_adaptor.hpp`,
`tile_distribution.hpp`, `tile_distribution_encoding.hpp`) where the
inner loop bound depended on the outer variable.
### Files changed by family
| Family | Files | Sites |
|--------|-------|-------|
| Block GEMM | 12 | ~20 |
| FlatMM pipelines | 4 | ~69 (including 5 ford-ford merges) |
| GEMM quant | 7 | ~22 |
| FlatMM kernel | 1 | 2 |
| FMHA | 1 | 2 |
| Reduce/norm | 2 | 2 |
| Epilogue | 1 | 1 |
### Blocked locations from review comments
- **block_gemm_areg_breg_creg_v1.hpp:356** — BLOCKED: runtime scale
loads (`scale_a_slice`, `scale_b_slice`, A warp tensor load) between
every nesting level
- **block_universal_gemm_ar_aquant_flatbr_bquant_cr.hpp:228** — BLOCKED:
`zero_accumulators()` before inner loop; `sched_barrier` + conditional
`block_sync_lds()` after inner loop
- **block_universal_gemm_as_aquant_bs_bquant_cr.hpp:298** — BLOCKED:
runtime `CWarpTensor` construction before inner loop; quantization scale
application code after inner loop
- **block_universal_gemm_as_aquant_bs_cr.hpp:277** — BLOCKED: same
pattern as above
- **block_universal_gemm_as_bs_bquant_cr.hpp:367** — BLOCKED: same
pattern as above
## Depends on
- #5938 ([CK_TILE] Optimize static_ford and sequence compile-time
infrastructure) — provides the `ford_applier` that makes these
conversions beneficial. Without it, `static_ford` uses a recursive
implementation that provides no IR function savings.
## Results (combined with #5938)
### Build Time (Wilcoxon signed-rank, 7 paired trials, gfx942)
| Target | Base (s) | Treat (s) | Delta | % | Significant? |
|--------|----------|-----------|-------|---|-------------|
| **flatmm** | 161.1 | 149.0 | **-12.1s** | **-7.5%** | **YES** (p<0.01,
7/7 wins) |
| **universal_gemm** | 225.4 | 220.3 | **-5.1s** | **-2.3%** | **YES**
(p<0.01, 7/7 wins) |
### IR Function Counts (device trace, gfx942)
| Target | InstFunc | CodeGen |
|--------|----------|---------|
| universal_gemm | **-8.5%** | **-9.2%** |
| flatmm | **-7.6%** | **-10.5%** |
### ASM Equivalence
5/5 PASS — 650,151 lines verified identical (gfx942). TUs:
universal_gemm, flatmm_basic, fmha_bwd, reduce, bscale.
## Test plan
- [x] ASM equivalence verified (650K lines, gfx942)
- [x] Wilcoxon timing verified (7 trials, p<0.01)
- [x] IR function counts verified (-7.6% to -10.5% CodeGen reduction)
- [ ] CI
🤖 Generated with [Claude Code](https://claude.com/claude-code)
---------
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
341fb33386
commit
7816812ef8
@@ -92,29 +92,29 @@ struct BlockFlatmmASmemBSmemCRegV1
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block window
|
||||
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
// read A warp tensor from A block window
|
||||
const auto a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
__builtin_amdgcn_sched_barrier(0x7F6);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1105,15 +1105,14 @@ struct MoeFlatmmKernel
|
||||
statically_indexed_array<index_t, ScaleMRepeat> scale_m_offsets;
|
||||
|
||||
if constexpr(!BMXFP4_Pipeline)
|
||||
static_for<0, MRepeat, 1>{}([&](auto mIter) {
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
const auto row_idx =
|
||||
coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
|
||||
scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
|
||||
row_to_token_idx(row_idx);
|
||||
});
|
||||
});
|
||||
static_ford<sequence<MRepeat, kM0, kM2>>{}([&](auto mmm) {
|
||||
constexpr auto mIter = number<mmm[number<0>{}]>{};
|
||||
constexpr auto m0 = number<mmm[number<1>{}]>{};
|
||||
constexpr auto m2 = number<mmm[number<2>{}]>{};
|
||||
const auto row_idx =
|
||||
coord_m + mIter * MPerXdl + m0 * kM1 * kM2 + m2 + scale_m_coord[I0];
|
||||
scale_m_offsets[mIter * number<kM0 * kM2>{} + m0 * number<kM2>{} + m2] =
|
||||
row_to_token_idx(row_idx);
|
||||
});
|
||||
|
||||
constexpr int DynamicTileOffsetFlag = 0;
|
||||
@@ -1426,19 +1425,19 @@ struct MoeFlatmmKernel
|
||||
statically_indexed_array<statically_indexed_array<bool, MPerThread>, NumMEpiTile>
|
||||
c_scatter_valids;
|
||||
auto c_coord = dram_tile_distribution.calculate_index();
|
||||
static_for<0, NumMEpiTile, 1>{}([&](auto mIter) {
|
||||
static_for<0, MPerThread, 1>{}([&](auto m0) {
|
||||
auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
|
||||
auto fused_token =
|
||||
kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
|
||||
static_ford<sequence<NumMEpiTile, MPerThread>>{}([&](auto mm) {
|
||||
constexpr auto mIter = number<mm[number<0>{}]>{};
|
||||
constexpr auto m0 = number<mm[number<1>{}]>{};
|
||||
auto row_idx = coord_m + mIter * MPerIterationShuffle + c_coord[0] + m0;
|
||||
auto fused_token =
|
||||
kargs.p_sorted_token_ids[row_idx]; // topk-idx[31:24] + token_idx[23:0]
|
||||
|
||||
index_t scatter_token_id = fused_token & token_id_mask;
|
||||
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
|
||||
if constexpr(IsInputGemm)
|
||||
scatter_token_id =
|
||||
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
|
||||
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
|
||||
});
|
||||
index_t scatter_token_id = fused_token & token_id_mask;
|
||||
c_scatter_valids[mIter][m0] = (scatter_token_id < kargs.NumTokens);
|
||||
if constexpr(IsInputGemm)
|
||||
scatter_token_id =
|
||||
scatter_token_id * kargs.TopK + (fused_token >> token_id_offset);
|
||||
c_scatter_offsets[mIter][m0] = scatter_token_id * kargs.stride_C;
|
||||
});
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -606,16 +606,16 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
|
||||
constexpr auto mIter = number<mk[number<0>{}]>{};
|
||||
constexpr auto kIter = number<mk[number<1>{}]>{};
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
|
||||
// Block GEMM
|
||||
@@ -656,15 +656,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
|
||||
constexpr auto nIter = number<nk[number<0>{}]>{};
|
||||
constexpr auto kIter = number<nk[number<1>{}]>{};
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
@@ -701,15 +701,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
|
||||
constexpr auto kIter = number<kn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<kn[number<1>{}]>{};
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
@@ -722,44 +722,44 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
@@ -776,15 +776,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
// Next K
|
||||
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
|
||||
constexpr auto kIter = number<kn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<kn[number<1>{}]>{};
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
@@ -797,43 +797,43 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
@@ -854,15 +854,15 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
|
||||
constexpr auto kIter = number<kn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<kn[number<1>{}]>{};
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
@@ -870,44 +870,44 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
@@ -920,86 +920,86 @@ defined(USING_MFMA_32x32x64) && defined(ENABLE_FP4) // mi350 fp4 32c 1*K1
|
||||
Last2ndHotLoopScheduler();
|
||||
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -529,22 +529,22 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
MIterPerWarp>
|
||||
a_warp_windows_pong;
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
|
||||
constexpr auto mIter = number<mk[number<0>{}]>{};
|
||||
constexpr auto kIter = number<mk[number<1>{}]>{};
|
||||
a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
move_tile_window(a_warp_windows_ping(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
static_ford<sequence<MIterPerWarp, KIterPerWarp>>{}([&](auto mk) {
|
||||
constexpr auto mIter = number<mk[number<0>{}]>{};
|
||||
constexpr auto kIter = number<mk[number<1>{}]>{};
|
||||
a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
|
||||
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
move_tile_window(a_warp_windows_pong(mIter)(kIter),
|
||||
{mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
|
||||
// Block GEMM
|
||||
@@ -592,26 +592,26 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
2;
|
||||
|
||||
// prefetch B
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
static_ford<sequence<NIterPerWarp, KIterPerWarp>>{}([&](auto nk) {
|
||||
constexpr auto nIter = number<nk[number<0>{}]>{};
|
||||
constexpr auto kIter = number<nk[number<1>{}]>{};
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
{
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
{
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
// move B window to next flat K
|
||||
move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
@@ -648,28 +648,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
|
||||
constexpr auto kIter = number<kn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<kn[number<1>{}]>{};
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode)
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
{
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
{
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
|
||||
// Prefill A(2i+1)
|
||||
@@ -682,44 +681,44 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
@@ -736,28 +735,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
// Next K
|
||||
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
|
||||
constexpr auto kIter = number<kn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<kn[number<1>{}]>{};
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode)
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
{
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
{
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
|
||||
// Prefill A(2i+2)
|
||||
@@ -770,43 +768,43 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
move_tile_window(a_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// GEMM 2i+1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
|
||||
// move B window to next flat K
|
||||
@@ -827,28 +825,27 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
|
||||
constexpr auto kIter = number<kn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<kn[number<1>{}]>{};
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode)
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
{
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
{
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
|
||||
// Prefill A(loopK)
|
||||
@@ -856,44 +853,44 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
|
||||
|
||||
// GEMM loopK-1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, m_preload, 1>{}([&](auto loadIter) {
|
||||
@@ -906,86 +903,86 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
Last2ndHotLoopScheduler();
|
||||
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_pong(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// GEMM loopK
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
static_ford<sequence<KIterPerWarp, MIterPerWarp>>{}([&](auto km) {
|
||||
constexpr auto kIter = number<km[number<0>{}]>{};
|
||||
constexpr auto mIter = number<km[number<1>{}]>{};
|
||||
constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor,
|
||||
a_warp_tensor(number<AwarpIter>{}),
|
||||
b_warp_tensor_ping(nIter)(kIter));
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
// preload next A from lds
|
||||
if constexpr((kIter * MIterPerWarp + mIter) <
|
||||
(KIterPerWarp * MIterPerWarp - m_preload))
|
||||
{
|
||||
constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
|
||||
constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
|
||||
a_warp_tensor(number<AwarpIter>{}) =
|
||||
load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
|
||||
}
|
||||
|
||||
// barrier
|
||||
if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
|
||||
{
|
||||
block_sync_lds();
|
||||
}
|
||||
});
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
|
||||
@@ -486,13 +486,13 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
auto c_block_tile = BlockFlatmm{}.MakeCBlockTile();
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensors(mIter)(nIter).get_thread_buffer());
|
||||
});
|
||||
static_ford<sequence<MIterPerWarp, NIterPerWarp>>{}([&](auto mn) {
|
||||
constexpr auto mIter = number<mn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<mn[number<1>{}]>{};
|
||||
c_block_tile.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensors(mIter)(nIter).get_thread_buffer());
|
||||
});
|
||||
return c_block_tile;
|
||||
}
|
||||
@@ -643,24 +643,23 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
});
|
||||
|
||||
// prefetch Scale A
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
|
||||
scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
static_ford<sequence<MPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
|
||||
constexpr auto impack = number<ii[number<0>{}]>{};
|
||||
constexpr auto ikpack = number<ii[number<1>{}]>{};
|
||||
scale_a_tile_tensor_ping(impack)(ikpack) =
|
||||
load_tile_with_offset(scale_a_dram_window,
|
||||
|
||||
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
|
||||
});
|
||||
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
|
||||
});
|
||||
// move Scale A window to next K
|
||||
move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
|
||||
// prefetch Scale B
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
|
||||
scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
|
||||
});
|
||||
static_ford<sequence<NPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
|
||||
constexpr auto inpack = number<ii[number<0>{}]>{};
|
||||
constexpr auto ikpack = number<ii[number<1>{}]>{};
|
||||
scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
|
||||
scale_b_dram_window, inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
|
||||
});
|
||||
// move Scale B window to next K
|
||||
move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)});
|
||||
@@ -698,34 +697,34 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
// MAIN LOOP
|
||||
auto main_body_implx2 = [&]() mutable {
|
||||
// prefetch B(2i+1)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
|
||||
constexpr auto kIter = number<kn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<kn[number<1>{}]>{};
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
|
||||
});
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
|
||||
scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
|
||||
});
|
||||
static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
|
||||
constexpr auto ikpack = number<ii[number<0>{}]>{};
|
||||
constexpr auto impack = number<ii[number<1>{}]>{};
|
||||
scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
|
||||
});
|
||||
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
|
||||
scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
|
||||
});
|
||||
static_ford<sequence<KPackIterPerWarp, NPackIterPerWarp>>{}([&](auto ii) {
|
||||
constexpr auto ikpack = number<ii[number<0>{}]>{};
|
||||
constexpr auto inpack = number<ii[number<1>{}]>{};
|
||||
scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
|
||||
});
|
||||
|
||||
// GEMM 2i
|
||||
@@ -788,34 +787,34 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
////////////////////////////// Next K //////////////////////////////
|
||||
|
||||
// prefetch B(2i+2)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
|
||||
constexpr auto kIter = number<kn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<kn[number<1>{}]>{};
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
|
||||
});
|
||||
// move B window to next flat K
|
||||
if constexpr(kIter == KIterPerWarp - 1)
|
||||
b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset(
|
||||
tuple<number<0>, number<KIterPerWarp * KFlatBytesPerBlockPerIter>>{});
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+2)
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
|
||||
scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
|
||||
});
|
||||
static_ford<sequence<KPackIterPerWarp, MPackIterPerWarp>>{}([&](auto ii) {
|
||||
constexpr auto ikpack = number<ii[number<0>{}]>{};
|
||||
constexpr auto impack = number<ii[number<1>{}]>{};
|
||||
scale_a_tile_tensor_ping(impack)(ikpack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
|
||||
});
|
||||
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
|
||||
scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
|
||||
});
|
||||
static_ford<sequence<KPackIterPerWarp, NPackIterPerWarp>>{}([&](auto ii) {
|
||||
constexpr auto ikpack = number<ii[number<0>{}]>{};
|
||||
constexpr auto inpack = number<ii[number<1>{}]>{};
|
||||
scale_b_tile_tensor_ping(inpack)(ikpack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
|
||||
});
|
||||
|
||||
// GEMM 2i+1
|
||||
@@ -888,28 +887,28 @@ struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1<Problem
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
// prefetch B(loopK)
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
});
|
||||
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
|
||||
constexpr auto kIter = number<kn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<kn[number<1>{}]>{};
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset(
|
||||
b_flat_dram_window,
|
||||
b_flat_dram_offsets(nIter) + kIter * KFlatBytesPerBlockPerIter);
|
||||
});
|
||||
|
||||
// prefetch Scale A and Scale B (2i+1)
|
||||
static_for<0, MPackIterPerWarp, 1>{}([&](auto impack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
|
||||
scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
|
||||
});
|
||||
static_ford<sequence<MPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
|
||||
constexpr auto impack = number<ii[number<0>{}]>{};
|
||||
constexpr auto ikpack = number<ii[number<1>{}]>{};
|
||||
scale_a_tile_tensor_pong(impack)(ikpack) = load_tile_with_offset(
|
||||
scale_a_dram_window,
|
||||
impack * scale_a_dram_step_m + ikpack * scale_a_dram_step_k);
|
||||
});
|
||||
static_for<0, NPackIterPerWarp, 1>{}([&](auto inpack) {
|
||||
static_for<0, KPackIterPerWarp, 1>{}([&](auto ikpack) {
|
||||
scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
|
||||
});
|
||||
static_ford<sequence<NPackIterPerWarp, KPackIterPerWarp>>{}([&](auto ii) {
|
||||
constexpr auto inpack = number<ii[number<0>{}]>{};
|
||||
constexpr auto ikpack = number<ii[number<1>{}]>{};
|
||||
scale_b_tile_tensor_pong(inpack)(ikpack) = load_tile_with_offset(
|
||||
scale_b_dram_window,
|
||||
inpack * scale_b_dram_step_n + ikpack * scale_b_dram_step_k);
|
||||
});
|
||||
|
||||
// GEMM loopK-1
|
||||
|
||||
Reference in New Issue
Block a user